diff --git a/tests/__init__.py b/tests/__init__.py
index 4c8633b445..775bec0227 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bd229cf7e9..dbcb13c0a5 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015 - 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -128,7 +127,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
token="foobar",
url="a_url",
sender=self.test_user,
- ip_range_whitelist=IPSet(["192.168.0.0/16"]),
+ ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
@@ -147,7 +146,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
token="foobar",
url="a_url",
sender=self.test_user,
- ip_range_whitelist=IPSet(["192.168.0.0/16"]),
+ ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = AsyncMock(return_value=None)
diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py
index efa3addf00..5e324595a4 100644
--- a/tests/api/test_errors.py
+++ b/tests/api/test_errors.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -33,14 +32,18 @@ class LimitExceededErrorTestCase(unittest.TestCase):
self.assertIn("needle", err.debug_context)
self.assertNotIn("needle", serialised)
+ # Create a sub-class to avoid mutating the class-level property.
+ class LimitExceededErrorHeaders(LimitExceededError):
+ include_retry_after_header = True
+
def test_limit_exceeded_header(self) -> None:
- err = LimitExceededError(limiter_name="test", retry_after_ms=100)
+ err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "1")
def test_limit_exceeded_rounding(self) -> None:
- err = LimitExceededError(limiter_name="test", retry_after_ms=3001)
+ err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "4")
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 743c52d969..73678b6f33 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -1,9 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index a59e168db1..a24638c9ef 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -116,9 +116,8 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Should raise
with self.assertRaises(LimitExceededError) as context:
self.get_success_or_raise(
- limiter.ratelimit(None, key="test_id", _time_now_s=5), by=0.5
+ limiter.ratelimit(None, key="test_id", _time_now_s=5)
)
-
self.assertEqual(context.exception.retry_after_ms, 5000)
# Shouldn't raise
@@ -193,7 +192,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Second attempt, 1s later, will fail
with self.assertRaises(LimitExceededError) as context:
self.get_success_or_raise(
- limiter.ratelimit(None, key=("test_id",), _time_now_s=1), by=0.5
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
)
self.assertEqual(context.exception.retry_after_ms, 9000)
diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py
index 9dc20800b2..576323e391 100644
--- a/tests/app/test_homeserver_start.py
+++ b/tests/app/test_homeserver_start.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/appservice/__init__.py b/tests/appservice/__init__.py
index 6a72062b0c..3d833a2e44 100644
--- a/tests/appservice/__init__.py
+++ b/tests/appservice/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 0f19736540..344be6e1de 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 3fa4426638..fd7d089edf 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index a1c7ccdd0b..b6ecacccb5 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
index 587ee42067..3d833a2e44 100644
--- a/tests/config/__init__.py
+++ b/tests/config/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test___main__.py b/tests/config/test___main__.py
index ca8884823c..7af635f5ec 100644
--- a/tests/config/test___main__.py
+++ b/tests/config/test___main__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_appservice.py b/tests/config/test_appservice.py
index e3021b59d8..3a2e268d73 100644
--- a/tests/config/test_appservice.py
+++ b/tests/config/test_appservice.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py
index 678a068481..984365d595 100644
--- a/tests/config/test_background_update.py
+++ b/tests/config/test_background_update.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index edb91bc4d9..41ade50714 100644
--- a/tests/config/test_base.py
+++ b/tests/config/test_base.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 631263b5ca..2108b40ff4 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 5ab96e16e1..95b13defba 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 479d2aab91..ef9976fb8d 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
index 713bddeb90..79c10b10a6 100644
--- a/tests/config/test_oauth_delegation.py
+++ b/tests/config/test_oauth_delegation.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index 0d52a96858..8267089298 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py
index 7fd6df2f93..16e3e13dd0 100644
--- a/tests/config/test_registration_config.py
+++ b/tests/config/test_registration_config.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index e25f7787f4..574697cfd9 100644
--- a/tests/config/test_room_directory.py
+++ b/tests/config/test_room_directory.py
@@ -17,31 +17,15 @@
# [This file includes modifications made by New Vector Limited]
#
#
-import yaml
-from twisted.test.proto_helpers import MemoryReactor
+import yaml
-import synapse.rest.admin
-import synapse.rest.client.login
-import synapse.rest.client.room
from synapse.config.room_directory import RoomDirectoryConfig
-from synapse.server import HomeServer
-from synapse.util import Clock
from tests import unittest
-from tests.unittest import override_config
-
-class RoomDirectoryConfigTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- servlets = [
- synapse.rest.admin.register_servlets,
- synapse.rest.client.login.register_servlets,
- synapse.rest.client.room.register_servlets,
- ]
+class RoomDirectoryConfigTestCase(unittest.TestCase):
def test_alias_creation_acl(self) -> None:
config = yaml.safe_load(
"""
@@ -183,20 +167,3 @@ class RoomDirectoryConfigTestCase(unittest.HomeserverTestCase):
aliases=["#unofficial_st:example.com", "#blah:example.com"],
)
)
-
- @override_config({"room_list_publication_rules": []})
- def test_room_creation_when_publishing_denied(self) -> None:
- """
- Test that when room publishing is denied via the config that new rooms can
- still be created and that the newly created room is not public.
- """
-
- user = self.register_user("alice", "pass")
- token = self.login("alice", "pass")
- room_id = self.helper.create_room_as(user, is_public=True, tok=token)
-
- res = self.get_success(self.store.get_room(room_id))
- assert res is not None
- is_public, _ = res
-
- self.assertFalse(is_public)
diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index 5c2cefaf5b..6cfdd181f0 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
index 64538a4628..8071850cce 100644
--- a/tests/config/test_util.py
+++ b/tests/config/test_util.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/test_workers.py b/tests/config/test_workers.py
index 64c0285d01..7a97246064 100644
--- a/tests/config/test_workers.py
+++ b/tests/config/test_workers.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/config/utils.py b/tests/config/utils.py
index 11140ff979..7842b88e72 100644
--- a/tests/config/utils.py
+++ b/tests/config/utils.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/crypto/__init__.py b/tests/crypto/__init__.py
index fcd2134c89..3d833a2e44 100644
--- a/tests/crypto/__init__.py
+++ b/tests/crypto/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index d7b9fb8bc6..f8369f75f7 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3bfaf1c80d..c33b03d37d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2017-2021 The Matrix.org Foundation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py
deleted file mode 100644
index 7fb4d4fa90..0000000000
--- a/tests/events/test_auto_accept_invites.py
+++ /dev/null
@@ -1,657 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2021 The Matrix.org Foundation C.I.C
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import asyncio
-from asyncio import Future
-from http import HTTPStatus
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast
-from unittest.mock import Mock
-
-import attr
-from parameterized import parameterized
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
-from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
-from synapse.events.auto_accept_invites import InviteAutoAccepter
-from synapse.federation.federation_base import event_from_pdu_json
-from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion
-from synapse.module_api import ModuleApi
-from synapse.rest import admin
-from synapse.rest.client import login, room
-from synapse.server import HomeServer
-from synapse.types import StreamToken, create_requester
-from synapse.util import Clock
-
-from tests.handlers.test_sync import generate_sync_config
-from tests.unittest import (
- FederatingHomeserverTestCase,
- HomeserverTestCase,
- TestCase,
- override_config,
-)
-
-
-class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
- """
- Integration test cases for auto-accepting invites.
- """
-
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver()
- self.handler = hs.get_federation_handler()
- self.store = hs.get_datastores().main
- return hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sync_handler = self.hs.get_sync_handler()
- self.module_api = hs.get_module_api()
-
- @parameterized.expand(
- [
- [False],
- [True],
- ]
- )
- @override_config(
- {
- "auto_accept_invites": {
- "enabled": True,
- },
- }
- )
- def test_auto_accept_invites(self, direct_room: bool) -> None:
- """Test that a user automatically joins a room when invited, if the
- module is enabled.
- """
- # A local user who sends an invite
- inviting_user_id = self.register_user("inviter", "pass")
- inviting_user_tok = self.login("inviter", "pass")
-
- # A local user who receives an invite
- invited_user_id = self.register_user("invitee", "pass")
- self.login("invitee", "pass")
-
- # Create a room and send an invite to the other user
- room_id = self.helper.create_room_as(
- inviting_user_id,
- is_public=False,
- tok=inviting_user_tok,
- )
-
- self.helper.invite(
- room_id,
- inviting_user_id,
- invited_user_id,
- tok=inviting_user_tok,
- extra_data={"is_direct": direct_room},
- )
-
- # Check that the invite receiving user has automatically joined the room when syncing
- join_updates, _ = sync_join(self, invited_user_id)
- self.assertEqual(len(join_updates), 1)
-
- join_update: JoinedSyncResult = join_updates[0]
- self.assertEqual(join_update.room_id, room_id)
-
- @override_config(
- {
- "auto_accept_invites": {
- "enabled": False,
- },
- }
- )
- def test_module_not_enabled(self) -> None:
- """Test that a user does not automatically join a room when invited,
- if the module is not enabled.
- """
- # A local user who sends an invite
- inviting_user_id = self.register_user("inviter", "pass")
- inviting_user_tok = self.login("inviter", "pass")
-
- # A local user who receives an invite
- invited_user_id = self.register_user("invitee", "pass")
- self.login("invitee", "pass")
-
- # Create a room and send an invite to the other user
- room_id = self.helper.create_room_as(
- inviting_user_id, is_public=False, tok=inviting_user_tok
- )
-
- self.helper.invite(
- room_id,
- inviting_user_id,
- invited_user_id,
- tok=inviting_user_tok,
- )
-
- # Check that the invite receiving user has not automatically joined the room when syncing
- join_updates, _ = sync_join(self, invited_user_id)
- self.assertEqual(len(join_updates), 0)
-
- @override_config(
- {
- "auto_accept_invites": {
- "enabled": True,
- },
- }
- )
- def test_invite_from_remote_user(self) -> None:
- """Test that an invite from a remote user results in the invited user
- automatically joining the room.
- """
- # A remote user who sends the invite
- remote_server = "otherserver"
- remote_user = "@otheruser:" + remote_server
-
- # A local user who creates the room
- creator_user_id = self.register_user("creator", "pass")
- creator_user_tok = self.login("creator", "pass")
-
- # A local user who receives an invite
- invited_user_id = self.register_user("invitee", "pass")
- self.login("invitee", "pass")
-
- room_id = self.helper.create_room_as(
- room_creator=creator_user_id, tok=creator_user_tok
- )
- room_version = self.get_success(self.store.get_room_version(room_id))
-
- invite_event = event_from_pdu_json(
- {
- "type": EventTypes.Member,
- "content": {"membership": "invite"},
- "room_id": room_id,
- "sender": remote_user,
- "state_key": invited_user_id,
- "depth": 32,
- "prev_events": [],
- "auth_events": [],
- "origin_server_ts": self.clock.time_msec(),
- },
- room_version,
- )
- self.get_success(
- self.handler.on_invite_request(
- remote_server,
- invite_event,
- invite_event.room_version,
- )
- )
-
- # Check that the invite receiving user has automatically joined the room when syncing
- join_updates, _ = sync_join(self, invited_user_id)
- self.assertEqual(len(join_updates), 1)
-
- join_update: JoinedSyncResult = join_updates[0]
- self.assertEqual(join_update.room_id, room_id)
-
- @parameterized.expand(
- [
- [False, False],
- [True, True],
- ]
- )
- @override_config(
- {
- "auto_accept_invites": {
- "enabled": True,
- "only_for_direct_messages": True,
- },
- }
- )
- def test_accept_invite_direct_message(
- self,
- direct_room: bool,
- expect_auto_join: bool,
- ) -> None:
- """Tests that, if the module is configured to only accept DM invites, invites to DM rooms are still
- automatically accepted. Otherwise they are rejected.
- """
- # A local user who sends an invite
- inviting_user_id = self.register_user("inviter", "pass")
- inviting_user_tok = self.login("inviter", "pass")
-
- # A local user who receives an invite
- invited_user_id = self.register_user("invitee", "pass")
- self.login("invitee", "pass")
-
- # Create a room and send an invite to the other user
- room_id = self.helper.create_room_as(
- inviting_user_id,
- is_public=False,
- tok=inviting_user_tok,
- )
-
- self.helper.invite(
- room_id,
- inviting_user_id,
- invited_user_id,
- tok=inviting_user_tok,
- extra_data={"is_direct": direct_room},
- )
-
- if expect_auto_join:
- # Check that the invite receiving user has automatically joined the room when syncing
- join_updates, _ = sync_join(self, invited_user_id)
- self.assertEqual(len(join_updates), 1)
-
- join_update: JoinedSyncResult = join_updates[0]
- self.assertEqual(join_update.room_id, room_id)
- else:
- # Check that the invite receiving user has not automatically joined the room when syncing
- join_updates, _ = sync_join(self, invited_user_id)
- self.assertEqual(len(join_updates), 0)
-
- @parameterized.expand(
- [
- [False, True],
- [True, False],
- ]
- )
- @override_config(
- {
- "auto_accept_invites": {
- "enabled": True,
- "only_from_local_users": True,
- },
- }
- )
- def test_accept_invite_local_user(
- self, remote_inviter: bool, expect_auto_join: bool
- ) -> None:
- """Tests that, if the module is configured to only accept invites from local users, invites
- from local users are still automatically accepted. Otherwise they are rejected.
- """
- # A local user who sends an invite
- creator_user_id = self.register_user("inviter", "pass")
- creator_user_tok = self.login("inviter", "pass")
-
- # A local user who receives an invite
- invited_user_id = self.register_user("invitee", "pass")
- self.login("invitee", "pass")
-
- # Create a room and send an invite to the other user
- room_id = self.helper.create_room_as(
- creator_user_id, is_public=False, tok=creator_user_tok
- )
-
- if remote_inviter:
- room_version = self.get_success(self.store.get_room_version(room_id))
-
- # A remote user who sends the invite
- remote_server = "otherserver"
- remote_user = "@otheruser:" + remote_server
-
- invite_event = event_from_pdu_json(
- {
- "type": EventTypes.Member,
- "content": {"membership": "invite"},
- "room_id": room_id,
- "sender": remote_user,
- "state_key": invited_user_id,
- "depth": 32,
- "prev_events": [],
- "auth_events": [],
- "origin_server_ts": self.clock.time_msec(),
- },
- room_version,
- )
- self.get_success(
- self.handler.on_invite_request(
- remote_server,
- invite_event,
- invite_event.room_version,
- )
- )
- else:
- self.helper.invite(
- room_id,
- creator_user_id,
- invited_user_id,
- tok=creator_user_tok,
- )
-
- if expect_auto_join:
- # Check that the invite receiving user has automatically joined the room when syncing
- join_updates, _ = sync_join(self, invited_user_id)
- self.assertEqual(len(join_updates), 1)
-
- join_update: JoinedSyncResult = join_updates[0]
- self.assertEqual(join_update.room_id, room_id)
- else:
- # Check that the invite receiving user has not automatically joined the room when syncing
- join_updates, _ = sync_join(self, invited_user_id)
- self.assertEqual(len(join_updates), 0)
-
-
-_request_key = 0
-
-
-def generate_request_key() -> SyncRequestKey:
- global _request_key
- _request_key += 1
- return ("request_key", _request_key)
-
-
-def sync_join(
- testcase: HomeserverTestCase,
- user_id: str,
- since_token: Optional[StreamToken] = None,
-) -> Tuple[List[JoinedSyncResult], StreamToken]:
- """Perform a sync request for the given user and return the user join updates
- they've received, as well as the next_batch token.
-
- This method assumes testcase.sync_handler points to the homeserver's sync handler.
-
- Args:
- testcase: The testcase that is currently being run.
- user_id: The ID of the user to generate a sync response for.
- since_token: An optional token to indicate from at what point to sync from.
-
- Returns:
- A tuple containing a list of join updates, and the sync response's
- next_batch token.
- """
- requester = create_requester(user_id)
- sync_config = generate_sync_config(requester.user.to_string())
- sync_result = testcase.get_success(
- testcase.hs.get_sync_handler().wait_for_sync_for_user(
- requester,
- sync_config,
- SyncVersion.SYNC_V2,
- generate_request_key(),
- since_token,
- )
- )
-
- return sync_result.joined, sync_result.next_batch
-
-
-class InviteAutoAccepterInternalTestCase(TestCase):
- """
- Test cases which exercise the internals of the InviteAutoAccepter.
- """
-
- def setUp(self) -> None:
- self.module = create_module()
- self.user_id = "@peter:test"
- self.invitee = "@lesley:test"
- self.remote_invitee = "@thomas:remote"
-
- # We know our module API is a mock, but mypy doesn't.
- self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment]
-
- async def test_accept_invite_with_failures(self) -> None:
- """Tests that receiving an invite for a local user makes the module attempt to
- make the invitee join the room. This test verifies that it works if the call to
- update membership returns exceptions before successfully completing and returning an event.
- """
- invite = MockEvent(
- sender="@inviter:test",
- state_key="@invitee:test",
- type="m.room.member",
- content={"membership": "invite"},
- )
-
- join_event = MockEvent(
- sender="someone",
- state_key="someone",
- type="m.room.member",
- content={"membership": "join"},
- )
- # the first two calls raise an exception while the third call is successful
- self.mocked_update_membership.side_effect = [
- SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
- SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
- make_awaitable(join_event),
- ]
-
- # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
- # EventBase.
- await self.module.on_new_event(event=invite) # type: ignore[arg-type]
-
- await self.retry_assertions(
- self.mocked_update_membership,
- 3,
- sender=invite.state_key,
- target=invite.state_key,
- room_id=invite.room_id,
- new_membership="join",
- )
-
- async def test_accept_invite_failures(self) -> None:
- """Tests that receiving an invite for a local user makes the module attempt to
- make the invitee join the room. This test verifies that if the update_membership call
- fails consistently, _retry_make_join will break the loop after the set number of retries and
- execution will continue.
- """
- invite = MockEvent(
- sender=self.user_id,
- state_key=self.invitee,
- type="m.room.member",
- content={"membership": "invite"},
- )
- self.mocked_update_membership.side_effect = SynapseError(
- HTTPStatus.FORBIDDEN, "Forbidden"
- )
-
- # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
- # EventBase.
- await self.module.on_new_event(event=invite) # type: ignore[arg-type]
-
- await self.retry_assertions(
- self.mocked_update_membership,
- 5,
- sender=invite.state_key,
- target=invite.state_key,
- room_id=invite.room_id,
- new_membership="join",
- )
-
- async def test_not_state(self) -> None:
- """Tests that receiving an invite that's not a state event does nothing."""
- invite = MockEvent(
- sender=self.user_id, type="m.room.member", content={"membership": "invite"}
- )
-
- # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
- # EventBase.
- await self.module.on_new_event(event=invite) # type: ignore[arg-type]
-
- self.mocked_update_membership.assert_not_called()
-
- async def test_not_invite(self) -> None:
- """Tests that receiving a membership update that's not an invite does nothing."""
- invite = MockEvent(
- sender=self.user_id,
- state_key=self.user_id,
- type="m.room.member",
- content={"membership": "join"},
- )
-
- # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
- # EventBase.
- await self.module.on_new_event(event=invite) # type: ignore[arg-type]
-
- self.mocked_update_membership.assert_not_called()
-
- async def test_not_membership(self) -> None:
- """Tests that receiving a state event that's not a membership update does
- nothing.
- """
- invite = MockEvent(
- sender=self.user_id,
- state_key=self.user_id,
- type="org.matrix.test",
- content={"foo": "bar"},
- )
-
- # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
- # EventBase.
- await self.module.on_new_event(event=invite) # type: ignore[arg-type]
-
- self.mocked_update_membership.assert_not_called()
-
- def test_config_parse(self) -> None:
- """Tests that a correct configuration parses."""
- config = {
- "auto_accept_invites": {
- "enabled": True,
- "only_for_direct_messages": True,
- "only_from_local_users": True,
- }
- }
- parsed_config = AutoAcceptInvitesConfig()
- parsed_config.read_config(config)
-
- self.assertTrue(parsed_config.enabled)
- self.assertTrue(parsed_config.accept_invites_only_for_direct_messages)
- self.assertTrue(parsed_config.accept_invites_only_from_local_users)
-
- def test_runs_on_only_one_worker(self) -> None:
- """
- Tests that the module only runs on the specified worker.
- """
- # By default, we run on the main process...
- main_module = create_module(
- config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None
- )
- cast(
- Mock, main_module._api.register_third_party_rules_callbacks
- ).assert_called_once()
-
- # ...and not on other workers (like synchrotrons)...
- sync_module = create_module(worker_name="synchrotron42")
- cast(
- Mock, sync_module._api.register_third_party_rules_callbacks
- ).assert_not_called()
-
- # ...unless we configured them to be the designated worker.
- specified_module = create_module(
- config_override={
- "auto_accept_invites": {
- "enabled": True,
- "worker_to_run_on": "account_data1",
- }
- },
- worker_name="account_data1",
- )
- cast(
- Mock, specified_module._api.register_third_party_rules_callbacks
- ).assert_called_once()
-
- async def retry_assertions(
- self, mock: Mock, call_count: int, **kwargs: Any
- ) -> None:
- """
- This is a hacky way to ensure that the assertions are not called before the other coroutine
- has a chance to call `update_room_membership`. It catches the exception caused by a failure,
- and sleeps the thread before retrying, up until 5 tries.
-
- Args:
- call_count: the number of times the mock should have been called
- mock: the mocked function we want to assert on
- kwargs: keyword arguments to assert that the mock was called with
- """
-
- i = 0
- while i < 5:
- try:
- # Check that the mocked method is called the expected amount of times and with the right
- # arguments to attempt to make the user join the room.
- mock.assert_called_with(**kwargs)
- self.assertEqual(call_count, mock.call_count)
- break
- except AssertionError as e:
- i += 1
- if i == 5:
- # we've used up the tries, force the test to fail as we've already caught the exception
- self.fail(e)
- await asyncio.sleep(1)
-
-
-@attr.s(auto_attribs=True)
-class MockEvent:
- """Mocks an event. Only exposes properties the module uses."""
-
- sender: str
- type: str
- content: Dict[str, Any]
- room_id: str = "!someroom"
- state_key: Optional[str] = None
-
- def is_state(self) -> bool:
- """Checks if the event is a state event by checking if it has a state key."""
- return self.state_key is not None
-
- @property
- def membership(self) -> str:
- """Extracts the membership from the event. Should only be called on an event
- that's a membership event, and will raise a KeyError otherwise.
- """
- membership: str = self.content["membership"]
- return membership
-
-
-T = TypeVar("T")
-TV = TypeVar("TV")
-
-
-async def make_awaitable(value: T) -> T:
- return value
-
-
-def make_multiple_awaitable(result: TV) -> Awaitable[TV]:
- """
- Makes an awaitable, suitable for mocking an `async` function.
- This uses Futures as they can be awaited multiple times so can be returned
- to multiple callers.
- """
- future: Future[TV] = Future()
- future.set_result(result)
- return future
-
-
-def create_module(
- config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None
-) -> InviteAutoAccepter:
- # Create a mock based on the ModuleApi spec, but override some mocked functions
- # because some capabilities are needed for running the tests.
- module_api = Mock(spec=ModuleApi)
- module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
- module_api.worker_name = worker_name
- module_api.sleep.return_value = make_multiple_awaitable(None)
-
- if config_override is None:
- config_override = {}
-
- config = AutoAcceptInvitesConfig()
- config.read_config(config_override)
-
- return InviteAutoAccepter(config, module_api)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index e48983ddfe..ce66241763 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -36,7 +35,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, StreamToken, create_requester
from synapse.util import Clock
-from tests.handlers.test_sync import SyncRequestKey, SyncVersion, generate_sync_config
+from tests.handlers.test_sync import generate_sync_config
from tests.unittest import (
FederatingHomeserverTestCase,
HomeserverTestCase,
@@ -498,15 +497,6 @@ def send_presence_update(
return channel.json_body
-_request_key = 0
-
-
-def generate_request_key() -> SyncRequestKey:
- global _request_key
- _request_key += 1
- return ("request_key", _request_key)
-
-
def sync_presence(
testcase: HomeserverTestCase,
user_id: str,
@@ -530,11 +520,7 @@ def sync_presence(
sync_config = generate_sync_config(requester.user.to_string())
sync_result = testcase.get_success(
testcase.hs.get_sync_handler().wait_for_sync_for_user(
- requester,
- sync_config,
- SyncVersion.SYNC_V2,
- generate_request_key(),
- since_token,
+ requester, sync_config, since_token
)
)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index f96bbe7705..61640c8c2f 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 30f8787758..bf0da95d12 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -32,7 +31,6 @@ from synapse.events.utils import (
PowerLevelsContent,
SerializeEventConfig,
_split_field,
- clone_event,
copy_and_fixup_power_levels_contents,
maybe_upsert_event_field,
prune_event,
@@ -612,32 +610,6 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
)
-class CloneEventTestCase(stdlib_unittest.TestCase):
- def test_unsigned_is_copied(self) -> None:
- original = make_event_from_dict(
- {
- "type": "A",
- "event_id": "$test:domain",
- "unsigned": {"a": 1, "b": 2},
- },
- RoomVersions.V1,
- {"txn_id": "txn"},
- )
- original.internal_metadata.stream_ordering = 1234
- self.assertEqual(original.internal_metadata.stream_ordering, 1234)
- original.internal_metadata.instance_name = "worker1"
- self.assertEqual(original.internal_metadata.instance_name, "worker1")
-
- cloned = clone_event(original)
- cloned.unsigned["b"] = 3
-
- self.assertEqual(original.unsigned, {"a": 1, "b": 2})
- self.assertEqual(cloned.unsigned, {"a": 1, "b": 3})
- self.assertEqual(cloned.internal_metadata.stream_ordering, 1234)
- self.assertEqual(cloned.internal_metadata.instance_name, "worker1")
- self.assertEqual(cloned.internal_metadata.txn_id, "txn")
-
-
class SerializeEventTestCase(stdlib_unittest.TestCase):
def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict:
return serialize_event(
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9bd97e5d4e..8a06695ff6 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 Matrix.org Foundation
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 585f3b798c..5313bb33a3 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 Matrix.org Federation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py
deleted file mode 100644
index 0dcf20f5f5..0000000000
--- a/tests/federation/test_federation_media.py
+++ /dev/null
@@ -1,258 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import io
-import os
-import shutil
-import tempfile
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.media.filepath import MediaFilePaths
-from synapse.media.media_storage import MediaStorage
-from synapse.media.storage_provider import (
- FileStorageProviderBackend,
- StorageProviderWrapper,
-)
-from synapse.server import HomeServer
-from synapse.types import UserID
-from synapse.util import Clock
-
-from tests import unittest
-from tests.media.test_media_storage import small_png
-from tests.test_utils import SMALL_PNG
-
-
-class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- super().prepare(reactor, clock, hs)
- self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
- self.addCleanup(shutil.rmtree, self.test_dir)
- self.primary_base_path = os.path.join(self.test_dir, "primary")
- self.secondary_base_path = os.path.join(self.test_dir, "secondary")
-
- hs.config.media.media_store_path = self.primary_base_path
-
- storage_providers = [
- StorageProviderWrapper(
- FileStorageProviderBackend(hs, self.secondary_base_path),
- store_local=True,
- store_remote=False,
- store_synchronous=True,
- )
- ]
-
- self.filepaths = MediaFilePaths(self.primary_base_path)
- self.media_storage = MediaStorage(
- hs, self.primary_base_path, self.filepaths, storage_providers
- )
- self.media_repo = hs.get_media_repository()
-
- def test_file_download(self) -> None:
- content = io.BytesIO(b"file_to_stream")
- content_uri = self.get_success(
- self.media_repo.create_content(
- "text/plain",
- "test_upload",
- content,
- 46,
- UserID.from_string("@user_id:whatever.org"),
- )
- )
- # test with a text file
- channel = self.make_signed_federation_request(
- "GET",
- f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
- )
- self.pump()
- self.assertEqual(200, channel.code)
-
- content_type = channel.headers.getRawHeaders("content-type")
- assert content_type is not None
- assert "multipart/mixed" in content_type[0]
- assert "boundary" in content_type[0]
-
- # extract boundary
- boundary = content_type[0].split("boundary=")[1]
- # split on boundary and check that json field and expected value exist
- stripped = channel.text_body.split("\r\n" + "--" + boundary)
- # TODO: the json object expected will change once MSC3911 is implemented, currently
- # {} is returned for all requests as a placeholder (per MSC3196)
- found_json = any(
- "\r\nContent-Type: application/json\r\n\r\n{}" in field
- for field in stripped
- )
- self.assertTrue(found_json)
-
- # check that the text file and expected value exist
- found_file = any(
- "\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream"
- in field
- for field in stripped
- )
- self.assertTrue(found_file)
-
- content = io.BytesIO(SMALL_PNG)
- content_uri = self.get_success(
- self.media_repo.create_content(
- "image/png",
- "test_png_upload",
- content,
- 67,
- UserID.from_string("@user_id:whatever.org"),
- )
- )
- # test with an image file
- channel = self.make_signed_federation_request(
- "GET",
- f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
- )
- self.pump()
- self.assertEqual(200, channel.code)
-
- content_type = channel.headers.getRawHeaders("content-type")
- assert content_type is not None
- assert "multipart/mixed" in content_type[0]
- assert "boundary" in content_type[0]
-
- # extract boundary
- boundary = content_type[0].split("boundary=")[1]
- # split on boundary and check that json field and expected value exist
- body = channel.result.get("body")
- assert body is not None
- stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8"))
- found_json = any(
- b"\r\nContent-Type: application/json\r\n\r\n{}" in field
- for field in stripped_bytes
- )
- self.assertTrue(found_json)
-
- # check that the png file exists and matches what was uploaded
- found_file = any(SMALL_PNG in field for field in stripped_bytes)
- self.assertTrue(found_file)
-
-
-class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- super().prepare(reactor, clock, hs)
- self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
- self.addCleanup(shutil.rmtree, self.test_dir)
- self.primary_base_path = os.path.join(self.test_dir, "primary")
- self.secondary_base_path = os.path.join(self.test_dir, "secondary")
-
- hs.config.media.media_store_path = self.primary_base_path
-
- storage_providers = [
- StorageProviderWrapper(
- FileStorageProviderBackend(hs, self.secondary_base_path),
- store_local=True,
- store_remote=False,
- store_synchronous=True,
- )
- ]
-
- self.filepaths = MediaFilePaths(self.primary_base_path)
- self.media_storage = MediaStorage(
- hs, self.primary_base_path, self.filepaths, storage_providers
- )
- self.media_repo = hs.get_media_repository()
-
- def test_thumbnail_download_scaled(self) -> None:
- content = io.BytesIO(small_png.data)
- content_uri = self.get_success(
- self.media_repo.create_content(
- "image/png",
- "test_png_thumbnail",
- content,
- 67,
- UserID.from_string("@user_id:whatever.org"),
- )
- )
- # test with an image file
- channel = self.make_signed_federation_request(
- "GET",
- f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=scale",
- )
- self.pump()
- self.assertEqual(200, channel.code)
-
- content_type = channel.headers.getRawHeaders("content-type")
- assert content_type is not None
- assert "multipart/mixed" in content_type[0]
- assert "boundary" in content_type[0]
-
- # extract boundary
- boundary = content_type[0].split("boundary=")[1]
- # split on boundary and check that json field and expected value exist
- body = channel.result.get("body")
- assert body is not None
- stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8"))
- found_json = any(
- b"\r\nContent-Type: application/json\r\n\r\n{}" in field
- for field in stripped_bytes
- )
- self.assertTrue(found_json)
-
- # check that the png file exists and matches the expected scaled bytes
- found_file = any(small_png.expected_scaled in field for field in stripped_bytes)
- self.assertTrue(found_file)
-
- def test_thumbnail_download_cropped(self) -> None:
- content = io.BytesIO(small_png.data)
- content_uri = self.get_success(
- self.media_repo.create_content(
- "image/png",
- "test_png_thumbnail",
- content,
- 67,
- UserID.from_string("@user_id:whatever.org"),
- )
- )
- # test with an image file
- channel = self.make_signed_federation_request(
- "GET",
- f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=crop",
- )
- self.pump()
- self.assertEqual(200, channel.code)
-
- content_type = channel.headers.getRawHeaders("content-type")
- assert content_type is not None
- assert "multipart/mixed" in content_type[0]
- assert "boundary" in content_type[0]
-
- # extract boundary
- boundary = content_type[0].split("boundary=")[1]
- # split on boundary and check that json field and expected value exist
- body = channel.result.get("body")
- assert body is not None
- stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8"))
- found_json = any(
- b"\r\nContent-Type: application/json\r\n\r\n{}" in field
- for field in stripped_bytes
- )
- self.assertTrue(found_json)
-
- # check that the png file exists and matches the expected cropped bytes
- found_file = any(
- small_png.expected_cropped in field for field in stripped_bytes
- )
- self.assertTrue(found_file)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 6a8887fe74..9073afc70e 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -27,8 +27,6 @@ from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
-from synapse.api.presence import UserPresenceState
-from synapse.federation.sender.per_destination_queue import MAX_PRESENCE_STATES_PER_EDU
from synapse.federation.units import Transaction
from synapse.handlers.device import DeviceHandler
from synapse.rest import admin
@@ -268,123 +266,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
-class FederationSenderPresenceTestCases(HomeserverTestCase):
- """
- Test federation sending for presence updates.
- """
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.federation_transport_client = Mock(spec=["send_transaction"])
- self.federation_transport_client.send_transaction = AsyncMock()
- hs = self.setup_test_homeserver(
- federation_transport_client=self.federation_transport_client,
- )
-
- return hs
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- config["federation_sender_instances"] = None
- return config
-
- def test_presence_simple(self) -> None:
- "Test that sending a single presence update works"
-
- mock_send_transaction: AsyncMock = (
- self.federation_transport_client.send_transaction
- )
- mock_send_transaction.return_value = {}
-
- sender = self.hs.get_federation_sender()
- self.get_success(
- sender.send_presence_to_destinations(
- [UserPresenceState.default("@user:test")],
- ["server"],
- )
- )
-
- self.pump()
-
- # expect a call to send_transaction
- mock_send_transaction.assert_awaited_once()
-
- json_cb = mock_send_transaction.call_args[0][1]
- data = json_cb()
- self.assertEqual(
- data["edus"],
- [
- {
- "edu_type": EduTypes.PRESENCE,
- "content": {
- "push": [
- {
- "presence": "offline",
- "user_id": "@user:test",
- }
- ]
- },
- }
- ],
- )
-
- def test_presence_batched(self) -> None:
- """Test that sending lots of presence updates to a destination are
- batched, rather than having them all sent in one EDU."""
-
- mock_send_transaction: AsyncMock = (
- self.federation_transport_client.send_transaction
- )
- mock_send_transaction.return_value = {}
-
- sender = self.hs.get_federation_sender()
-
- # We now send lots of presence updates to force the federation sender to
- # batch the mup.
- number_presence_updates_to_send = MAX_PRESENCE_STATES_PER_EDU * 2
- self.get_success(
- sender.send_presence_to_destinations(
- [
- UserPresenceState.default(f"@user{i}:test")
- for i in range(number_presence_updates_to_send)
- ],
- ["server"],
- )
- )
-
- self.pump()
-
- # We should have seen at least one transcation be sent by now.
- mock_send_transaction.assert_called()
-
- # We don't want to specify exactly how the presence EDUs get sent out,
- # could be one per transaction or multiple per transaction. We just want
- # to assert that a) each presence EDU has bounded number of updates, and
- # b) that all updates get sent out.
- presence_edus = []
- for transaction_call in mock_send_transaction.call_args_list:
- json_cb = transaction_call[0][1]
- data = json_cb()
-
- for edu in data["edus"]:
- self.assertEqual(edu.get("edu_type"), EduTypes.PRESENCE)
- presence_edus.append(edu)
-
- # A set of all user presence we see, this should end up matching the
- # number we sent out above.
- seen_users: Set[str] = set()
-
- for edu in presence_edus:
- presence_states = edu["content"]["push"]
-
- # This is where we actually check that the number of presence
- # updates is bounded.
- self.assertLessEqual(len(presence_states), MAX_PRESENCE_STATES_PER_EDU)
-
- seen_users.update(p["user_id"] for p in presence_states)
-
- self.assertEqual(len(seen_users), number_presence_updates_to_send)
-
-
class FederationSenderDevicesTestCases(HomeserverTestCase):
"""
Test federation sending to update devices.
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 88261450b1..8c9ef766d1 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 Matrix.org Federation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -67,23 +66,6 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
- def test_failed_edu_causes_500(self) -> None:
- """If the EDU handler fails, /send should return a 500."""
-
- async def failing_handler(_origin: str, _content: JsonDict) -> None:
- raise Exception("bleh")
-
- self.hs.get_federation_registry().register_edu_handler(
- "FAIL_EDU_TYPE", failing_handler
- )
-
- channel = self.make_signed_federation_request(
- "PUT",
- "/_matrix/federation/v1/send/txn",
- {"edus": [{"edu_type": "FAIL_EDU_TYPE", "content": {}}]},
- )
- self.assertEqual(500, channel.code, channel.result)
-
class ServerACLsTestCase(unittest.TestCase):
def test_blocked_server(self) -> None:
diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py
index dab387a504..3d833a2e44 100644
--- a/tests/federation/transport/server/__init__.py
+++ b/tests/federation/transport/server/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index 0e3b41ec4d..595349e70c 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -147,10 +146,3 @@ class BaseFederationAuthorizationTests(unittest.TestCase):
),
("foo", "ed25519:1", "sig", "bar"),
)
- # test that "optional whitespace(s)" (space and tabulation) are allowed between comma-separated auth-param components
- self.assertEqual(
- _parse_auth_header(
- b'X-Matrix origin=foo , key="ed25519:1", sig="sig", destination="bar", extra_field=ignored'
- ),
- ("foo", "ed25519:1", "sig", "bar"),
- )
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index 3d882f99f2..17e8ddb61e 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 166a01c1a2..2f9aefd2b6 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Matrix.org Federation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 0237369998..e50f4bb4a3 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -59,14 +58,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/send/txn_id_1234/",
content={
"edus": [
- {
- "edu_type": EduTypes.DEVICE_LIST_UPDATE,
- "content": {
- "device_id": "QBUAZIFURK",
- "stream_id": 0,
- "user_id": "@user:id",
- },
- },
+ {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}}
],
"pdus": [],
},
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 9ff853a83d..09793c56a9 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1eec0d43b7..2bce5afbd4 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c417431e85..af6e5a8a77 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index f41f7d36ad..36acbd790d 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index d7b54383db..1adc6d8854 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -21,13 +20,12 @@
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
+from synapse.api.constants import AccountDataTypes
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest import admin
-from synapse.rest.client import account, login, room
+from synapse.rest.client import account, login
from synapse.server import HomeServer
from synapse.synapse_rust.push import PushRule
-from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -38,7 +36,6 @@ class DeactivateAccountTestCase(HomeserverTestCase):
login.register_servlets,
admin.register_servlets,
account.register_servlets,
- room.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -46,7 +43,6 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self.user = self.register_user("user", "pass")
self.token = self.login("user", "pass")
- self.handler = self.hs.get_room_member_handler()
def _deactivate_my_account(self) -> None:
"""
@@ -344,142 +340,3 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self.assertEqual(req.code, 401, req)
self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}])
-
- def test_deactivate_account_rejects_invites(self) -> None:
- """
- Tests that deactivating an account rejects its invite memberships
- """
- # Create another user and room just for the invitation
- another_user = self.register_user("another_user", "pass")
- token = self.login("another_user", "pass")
- room_id = self.helper.create_room_as(another_user, is_public=False, tok=token)
-
- # Invite user to the created room
- invite_event, _ = self.get_success(
- self.handler.update_membership(
- requester=create_requester(another_user),
- target=UserID.from_string(self.user),
- room_id=room_id,
- action=Membership.INVITE,
- )
- )
-
- # Check that the invite exists
- invite = self.get_success(
- self._store.get_invited_rooms_for_local_user(self.user)
- )
- self.assertEqual(invite[0].event_id, invite_event)
-
- # Deactivate the user
- self._deactivate_my_account()
-
- # Check that the deactivated user has no invites in the room
- after_deactivate_invite = self.get_success(
- self._store.get_invited_rooms_for_local_user(self.user)
- )
- self.assertEqual(len(after_deactivate_invite), 0)
-
- def test_deactivate_account_rejects_knocks(self) -> None:
- """
- Tests that deactivating an account rejects its knock memberships
- """
- # Create another user and room just for the invitation
- another_user = self.register_user("another_user", "pass")
- token = self.login("another_user", "pass")
- room_id = self.helper.create_room_as(
- another_user,
- is_public=False,
- tok=token,
- )
-
- # Allow room to be knocked at
- self.helper.send_state(
- room_id,
- EventTypes.JoinRules,
- {"join_rule": JoinRules.KNOCK},
- tok=token,
- )
-
- # Knock user at the created room
- knock_event, _ = self.get_success(
- self.handler.update_membership(
- requester=create_requester(self.user),
- target=UserID.from_string(self.user),
- room_id=room_id,
- action=Membership.KNOCK,
- )
- )
-
- # Check that the knock exists
- knocks = self.get_success(
- self._store.get_knocked_at_rooms_for_local_user(self.user)
- )
- self.assertEqual(knocks[0].event_id, knock_event)
-
- # Deactivate the user
- self._deactivate_my_account()
-
- # Check that the deactivated user has no knocks
- after_deactivate_knocks = self.get_success(
- self._store.get_knocked_at_rooms_for_local_user(self.user)
- )
- self.assertEqual(len(after_deactivate_knocks), 0)
-
- def test_membership_is_redacted_upon_deactivation(self) -> None:
- """
- Tests that room membership events are redacted if erasure is requested.
- """
- # Create a room
- room_id = self.helper.create_room_as(
- self.user,
- is_public=True,
- tok=self.token,
- )
-
- # Change the displayname
- membership_event, _ = self.get_success(
- self.handler.update_membership(
- requester=create_requester(self.user),
- target=UserID.from_string(self.user),
- room_id=room_id,
- action=Membership.JOIN,
- content={"displayname": "Hello World!"},
- )
- )
-
- # Deactivate the account
- self._deactivate_my_account()
-
- # Get the all membership event IDs
- membership_event_ids = self.get_success(
- self._store.get_membership_event_ids_for_user(self.user, room_id=room_id)
- )
-
- # Get the events incl. JSON
- events = self.get_success(self._store.get_events_as_list(membership_event_ids))
-
- # Validate that there is no displayname in any of the events
- for event in events:
- self.assertTrue("displayname" not in event.content)
-
- def test_rooms_forgotten_upon_deactivation(self) -> None:
- """
- Tests that the user 'forgets' the rooms they left upon deactivation.
- """
- # Create a room
- room_id = self.helper.create_room_as(
- self.user,
- is_public=True,
- tok=self.token,
- )
-
- # Deactivate the account
- self._deactivate_my_account()
-
- # Get all of the user's forgotten rooms
- forgotten_rooms = self.get_success(
- self._store.get_forgotten_rooms_for_user(self.user)
- )
-
- # Validate that the created room is forgotten
- self.assertTrue(room_id in forgotten_rooms)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 080e6a7028..4369a50281 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 4a3e36ffde..03c360cca6 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 Matrix.org Foundation C.I.C.
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -498,27 +496,19 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.denied_user_id = self.register_user("denied", "pass")
self.denied_access_token = self.login("denied", "pass")
- self.store = hs.get_datastores().main
-
def test_denied_without_publication_permission(self) -> None:
"""
- Try to create a room, register a valid alias for it, and publish it,
+ Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms.
- The room should be created but not published.
+ (This is used as both a standalone test & as a helper function.)
"""
- room_id = self.helper.create_room_as(
+ self.helper.create_room_as(
self.denied_user_id,
tok=self.denied_access_token,
extra_content=self.data,
is_public=True,
- expect_code=200,
+ expect_code=403,
)
- res = self.get_success(self.store.get_room(room_id))
- assert res is not None
- is_public, _ = res
-
- # room creation completes but room is not published to directory
- self.assertEqual(is_public, False)
def test_allowed_when_creating_private_room(self) -> None:
"""
@@ -536,8 +526,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
def test_allowed_with_publication_permission(self) -> None:
"""
- Try to create a room, register a valid alias for it, and publish it,
+ Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
+ (This is used as both a standalone test & as a helper function.)
"""
self.helper.create_room_as(
self.allowed_user_id,
@@ -549,26 +540,38 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
def test_denied_publication_with_invalid_alias(self) -> None:
"""
- Try to create a room, register an invalid alias for it, and publish it,
+ Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
"""
- room_id = self.helper.create_room_as(
+ self.helper.create_room_as(
self.allowed_user_id,
tok=self.allowed_access_token,
extra_content={"room_alias_name": "foo"},
is_public=True,
- expect_code=200,
+ expect_code=403,
)
- # the room is created with the requested alias, but the room is not published
- res = self.get_success(self.store.get_room(room_id))
- assert res is not None
- is_public, _ = res
+ def test_can_create_as_private_room_after_rejection(self) -> None:
+ """
+ After failing to publish a room with an alias as a user without publish permission,
+ retry as the same user, but without publishing the room.
- self.assertFalse(is_public)
+ This should pass, but used to fail because the alias was registered by the first
+ request, even though the room creation was denied.
+ """
+ self.test_denied_without_publication_permission()
+ self.test_allowed_when_creating_private_room()
+
+ def test_can_create_with_permission_after_rejection(self) -> None:
+ """
+ After failing to publish a room with an alias as a user without publish permission,
+ retry as someone with permission, using the same alias.
- aliases = self.get_success(self.store.get_aliases_for_room(room_id))
- self.assertEqual(aliases[0], "#foo:test")
+ This also used to fail because of the alias having been registered by the first
+ request, leaving it unavailable for any other user's new rooms.
+ """
+ self.test_denied_without_publication_permission()
+ self.test_allowed_with_publication_permission()
class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8a3dfdcf75..aa375fa218 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -43,7 +41,9 @@ from tests.unittest import override_config
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.AsyncMock()
- return self.setup_test_homeserver(application_service_api=self.appservice_api)
+ return self.setup_test_homeserver(
+ federation_client=mock.Mock(), application_service_api=self.appservice_api
+ )
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
@@ -1099,56 +1099,6 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_has_different_keys(self) -> None:
- """check that has_different_keys returns True when the keys provided are different to what
- is in the database."""
- local_user = "@boris:" + self.hs.hostname
- keys1 = {
- "master_key": {
- # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
- "user_id": local_user,
- "usage": ["master"],
- "keys": {
- "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
- },
- }
- }
- self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
- is_different = self.get_success(
- self.handler.has_different_keys(
- local_user,
- {
- "master_key": keys1["master_key"],
- },
- )
- )
- self.assertEqual(is_different, False)
- # change the usage => different keys
- keys1["master_key"]["usage"] = ["develop"]
- is_different = self.get_success(
- self.handler.has_different_keys(
- local_user,
- {
- "master_key": keys1["master_key"],
- },
- )
- )
- self.assertEqual(is_different, True)
- keys1["master_key"]["usage"] = ["master"] # reset
- # change the key => different keys
- keys1["master_key"]["keys"] = {
- "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unIc0rncs": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unIc0rncs"
- }
- is_different = self.get_success(
- self.handler.has_different_keys(
- local_user,
- {
- "master_key": keys1["master_key"],
- },
- )
- )
- self.assertEqual(is_different, True)
-
def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys
@@ -1222,61 +1172,6 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_query_devices_remote_down(self) -> None:
- """Tests that querying keys for a remote user on an unreachable server returns
- results in the "failures" property
- """
-
- remote_user_id = "@test:other"
- local_user_id = "@test:test"
-
- # The backoff code treats time zero as special
- self.reactor.advance(5)
-
- self.hs.get_federation_http_client().agent.request = mock.AsyncMock( # type: ignore[method-assign]
- side_effect=Exception("boop")
- )
-
- e2e_handler = self.hs.get_e2e_keys_handler()
-
- query_result = self.get_success(
- e2e_handler.query_devices(
- {
- "device_keys": {remote_user_id: []},
- },
- timeout=10,
- from_user_id=local_user_id,
- from_device_id="some_device_id",
- )
- )
-
- self.assertEqual(
- query_result["failures"],
- {
- "other": {
- "message": "Failed to send request: Exception: boop",
- "status": 503,
- }
- },
- )
-
- # Do it again: we should hit the backoff
- query_result = self.get_success(
- e2e_handler.query_devices(
- {
- "device_keys": {remote_user_id: []},
- },
- timeout=10,
- from_user_id=local_user_id,
- from_device_id="some_device_id",
- )
- )
-
- self.assertEqual(
- query_result["failures"],
- {"other": {"message": "Not ready for retry", "status": 503}},
- )
-
@parameterized.expand(
[
# The remote homeserver's response indicates that this user has 0/1/2 devices.
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 3ec46402b7..3217639176 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 Matrix.org Foundation C.I.C.
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 3fe5b0a1b4..8e2ad42ab8 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -483,7 +482,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
event.room_version,
),
exc=LimitExceededError,
- by=0.5,
)
def _build_and_send_join_event(
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1b83aea579..938c5dec60 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 76ab83d1f7..802f96ed48 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -24,7 +23,6 @@ from typing import Tuple
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
@@ -52,15 +50,11 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self._persist_event_storage_controller = persistence
- self.store = self.hs.get_datastores().main
self.user_id = self.register_user("tester", "foobar")
device_id = "dev-1"
access_token = self.login("tester", "foobar", device_id=device_id)
self.room_id = self.helper.create_room_as(self.user_id, tok=access_token)
- self.private_room_id = self.helper.create_room_as(
- self.user_id, tok=access_token, extra_content={"preset": "private_chat"}
- )
self.requester = create_requester(self.user_id, device_id=device_id)
@@ -290,41 +284,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
AssertionError,
)
- def test_call_invite_event_creation_fails_in_public_room(self) -> None:
- # get prev_events for room
- prev_events = self.get_success(
- self.store.get_prev_events_for_room(self.room_id)
- )
-
- # the invite in a public room should fail
- self.get_failure(
- self.handler.create_event(
- self.requester,
- {
- "type": EventTypes.CallInvite,
- "room_id": self.room_id,
- "sender": self.requester.user.to_string(),
- },
- prev_event_ids=prev_events,
- auth_event_ids=prev_events,
- ),
- SynapseError,
- )
-
- # but a call invite in a private room should succeed
- self.get_success(
- self.handler.create_event(
- self.requester,
- {
- "type": EventTypes.CallInvite,
- "room_id": self.private_room_id,
- "sender": self.requester.user.to_string(),
- },
- prev_event_ids=prev_events,
- auth_event_ids=prev_events,
- )
- )
-
class ServerAclValidationTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 036c539db2..b9761d806d 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -541,8 +540,6 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
- # Try uploading *different* keys; it should cause a 501 error.
- keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
channel = self.make_request(
"POST",
"/_matrix/client/v3/keys/device_signing/upload",
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a81501979d..df30a80734 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Quentin Gliech
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ed203eb299..b3248fa491 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index cc630d606c..c68e8c6631 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cb1c6fbb80..d70026c31e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 7c5bec2b76..3c14dfa0a8 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 Šimon Brandner <simon.bra.ag@gmail.com>
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 92487692db..36255270ed 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 213a66ed1a..3e28117e2c 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -70,7 +70,6 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
action=Membership.JOIN,
),
LimitExceededError,
- by=0.5,
)
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
@@ -207,7 +206,6 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
remote_room_hosts=[self.OTHER_SERVER_NAME],
),
LimitExceededError,
- by=0.5,
)
# TODO: test that remote joins to a room are rate limited.
@@ -275,7 +273,6 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa
action=Membership.JOIN,
),
LimitExceededError,
- by=0.5,
)
# Try to join as Chris on the original worker. Should get denied because Alice
@@ -288,7 +285,6 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa
action=Membership.JOIN,
),
LimitExceededError,
- by=0.5,
)
@@ -407,24 +403,3 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
self.assertFalse(
self.get_success(self.store.did_forget(self.alice, self.room_id))
)
-
- def test_deduplicate_joins(self) -> None:
- """
- Test that calling /join multiple times does not store a new state group.
- """
-
- self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
-
- sql = "SELECT COUNT(*) FROM state_groups WHERE room_id = ?"
- rows = self.get_success(
- self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
- )
- initial_count = rows[0][0]
-
- self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
- rows = self.get_success(
- self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
- )
- new_count = rows[0][0]
-
- self.assertEqual(initial_count, new_count)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 244a4e7689..929772c412 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 6ab8fda6e7..9d7cd4ee6f 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index cedcea27d9..73ea7fc346 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
deleted file mode 100644
index 96da47f3b9..0000000000
--- a/tests/handlers/test_sliding_sync.py
+++ /dev/null
@@ -1,4567 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import logging
-from copy import deepcopy
-from typing import Dict, List, Optional
-from unittest.mock import patch
-
-from parameterized import parameterized
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.constants import (
- AccountDataTypes,
- EventContentFields,
- EventTypes,
- JoinRules,
- Membership,
- RoomTypes,
-)
-from synapse.api.room_versions import RoomVersions
-from synapse.events import StrippedStateEvent, make_event_from_dict
-from synapse.events.snapshot import EventContext
-from synapse.handlers.sliding_sync import (
- RoomSyncConfig,
- StateValues,
- _RoomMembershipForUser,
-)
-from synapse.rest import admin
-from synapse.rest.client import knock, login, room
-from synapse.server import HomeServer
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import JsonDict, StreamToken, UserID
-from synapse.types.handlers import SlidingSyncConfig
-from synapse.util import Clock
-
-from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.unittest import HomeserverTestCase, TestCase
-
-logger = logging.getLogger(__name__)
-
-
-class RoomSyncConfigTestCase(TestCase):
- def _assert_room_config_equal(
- self,
- actual: RoomSyncConfig,
- expected: RoomSyncConfig,
- message_prefix: Optional[str] = None,
- ) -> None:
- self.assertEqual(actual.timeline_limit, expected.timeline_limit, message_prefix)
-
- # `self.assertEqual(...)` works fine to catch differences but the output is
- # almost impossible to read because of the way it truncates the output and the
- # order doesn't actually matter.
- self.assertCountEqual(
- actual.required_state_map, expected.required_state_map, message_prefix
- )
- for event_type, expected_state_keys in expected.required_state_map.items():
- self.assertCountEqual(
- actual.required_state_map[event_type],
- expected_state_keys,
- f"{message_prefix}: Mismatch for {event_type}",
- )
-
- @parameterized.expand(
- [
- (
- "from_list_config",
- """
- Test that we can convert a `SlidingSyncConfig.SlidingSyncList` to a
- `RoomSyncConfig`.
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (EventTypes.Member, "@foo"),
- (EventTypes.Member, "@bar"),
- (EventTypes.Member, "@baz"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Name: {""},
- EventTypes.Member: {
- "@foo",
- "@bar",
- "@baz",
- },
- EventTypes.CanonicalAlias: {""},
- },
- ),
- ),
- (
- "from_room_subscription",
- """
- Test that we can convert a `SlidingSyncConfig.RoomSubscription` to a
- `RoomSyncConfig`.
- """,
- # Input
- SlidingSyncConfig.RoomSubscription(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (EventTypes.Member, "@foo"),
- (EventTypes.Member, "@bar"),
- (EventTypes.Member, "@baz"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Name: {""},
- EventTypes.Member: {
- "@foo",
- "@bar",
- "@baz",
- },
- EventTypes.CanonicalAlias: {""},
- },
- ),
- ),
- (
- "wildcard",
- """
- Test that a wildcard (*) for both the `event_type` and `state_key` will override
- all other values.
-
- Note: MSC3575 describes different behavior to how we're handling things here but
- since it's not wrong to return more state than requested (`required_state` is
- just the minimum requested), it doesn't matter if we include things that the
- client wanted excluded. This complexity is also under scrutiny, see
- https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1185109050
-
- > One unique exception is when you request all state events via ["*", "*"]. When used,
- > all state events are returned by default, and additional entries FILTER OUT the returned set
- > of state events. These additional entries cannot use '*' themselves.
- > For example, ["*", "*"], ["m.room.member", "@alice:example.com"] will _exclude_ every m.room.member
- > event _except_ for @alice:example.com, and include every other state event.
- > In addition, ["*", "*"], ["m.space.child", "*"] is an error, the m.space.child filter is not
- > required as it would have been returned anyway.
- >
- > -- MSC3575 (https://github.com/matrix-org/matrix-spec-proposals/pull/3575)
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (StateValues.WILDCARD, StateValues.WILDCARD),
- (EventTypes.Member, "@foo"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- StateValues.WILDCARD: {StateValues.WILDCARD},
- },
- ),
- ),
- (
- "wildcard_type",
- """
- Test that a wildcard (*) as a `event_type` will override all other values for the
- same `state_key`.
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (StateValues.WILDCARD, ""),
- (EventTypes.Member, "@foo"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- StateValues.WILDCARD: {""},
- EventTypes.Member: {"@foo"},
- },
- ),
- ),
- (
- "multiple_wildcard_type",
- """
- Test that multiple wildcard (*) as a `event_type` will override all other values
- for the same `state_key`.
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (StateValues.WILDCARD, ""),
- (EventTypes.Member, "@foo"),
- (StateValues.WILDCARD, "@foo"),
- ("org.matrix.personal_count", "@foo"),
- (EventTypes.Member, "@bar"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- StateValues.WILDCARD: {
- "",
- "@foo",
- },
- EventTypes.Member: {"@bar"},
- },
- ),
- ),
- (
- "wildcard_state_key",
- """
- Test that a wildcard (*) as a `state_key` will override all other values for the
- same `event_type`.
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (EventTypes.Member, "@foo"),
- (EventTypes.Member, StateValues.WILDCARD),
- (EventTypes.Member, "@bar"),
- (EventTypes.Member, StateValues.LAZY),
- (EventTypes.Member, "@baz"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Name: {""},
- EventTypes.Member: {
- StateValues.WILDCARD,
- },
- EventTypes.CanonicalAlias: {""},
- },
- ),
- ),
- (
- "wildcard_merge",
- """
- Test that a wildcard (*) entries for the `event_type` and another one for
- `state_key` will play together.
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (StateValues.WILDCARD, ""),
- (EventTypes.Member, "@foo"),
- (EventTypes.Member, StateValues.WILDCARD),
- (EventTypes.Member, "@bar"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- StateValues.WILDCARD: {""},
- EventTypes.Member: {StateValues.WILDCARD},
- },
- ),
- ),
- (
- "wildcard_merge2",
- """
- Test that an all wildcard ("*", "*") entry will override any other
- values (including other wildcards).
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (StateValues.WILDCARD, ""),
- (EventTypes.Member, StateValues.WILDCARD),
- (EventTypes.Member, "@foo"),
- # One of these should take precedence over everything else
- (StateValues.WILDCARD, StateValues.WILDCARD),
- (StateValues.WILDCARD, StateValues.WILDCARD),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- StateValues.WILDCARD: {StateValues.WILDCARD},
- },
- ),
- ),
- (
- "lazy_members",
- """
- `$LAZY` room members should just be another additional key next to other
- explicit keys. We will unroll the special `$LAZY` meaning later.
- """,
- # Input
- SlidingSyncConfig.SlidingSyncList(
- timeline_limit=10,
- required_state=[
- (EventTypes.Name, ""),
- (EventTypes.Member, "@foo"),
- (EventTypes.Member, "@bar"),
- (EventTypes.Member, StateValues.LAZY),
- (EventTypes.Member, "@baz"),
- (EventTypes.CanonicalAlias, ""),
- ],
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Name: {""},
- EventTypes.Member: {
- "@foo",
- "@bar",
- StateValues.LAZY,
- "@baz",
- },
- EventTypes.CanonicalAlias: {""},
- },
- ),
- ),
- ]
- )
- def test_from_room_config(
- self,
- _test_label: str,
- _test_description: str,
- room_params: SlidingSyncConfig.CommonRoomParameters,
- expected_room_sync_config: RoomSyncConfig,
- ) -> None:
- """
- Test `RoomSyncConfig.from_room_config(room_params)` will result in the `expected_room_sync_config`.
- """
- room_sync_config = RoomSyncConfig.from_room_config(room_params)
-
- self._assert_room_config_equal(
- room_sync_config,
- expected_room_sync_config,
- )
-
- @parameterized.expand(
- [
- (
- "no_direct_overlap",
- # A
- RoomSyncConfig(
- timeline_limit=9,
- required_state_map={
- EventTypes.Name: {""},
- EventTypes.Member: {
- "@foo",
- "@bar",
- },
- },
- ),
- # B
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Member: {
- StateValues.LAZY,
- "@baz",
- },
- EventTypes.CanonicalAlias: {""},
- },
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Name: {""},
- EventTypes.Member: {
- "@foo",
- "@bar",
- StateValues.LAZY,
- "@baz",
- },
- EventTypes.CanonicalAlias: {""},
- },
- ),
- ),
- (
- "wildcard_overlap",
- # A
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- StateValues.WILDCARD: {StateValues.WILDCARD},
- },
- ),
- # B
- RoomSyncConfig(
- timeline_limit=9,
- required_state_map={
- EventTypes.Dummy: {StateValues.WILDCARD},
- StateValues.WILDCARD: {"@bar"},
- EventTypes.Member: {"@foo"},
- },
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- StateValues.WILDCARD: {StateValues.WILDCARD},
- },
- ),
- ),
- (
- "state_type_wildcard_overlap",
- # A
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Dummy: {"dummy"},
- StateValues.WILDCARD: {
- "",
- "@foo",
- },
- EventTypes.Member: {"@bar"},
- },
- ),
- # B
- RoomSyncConfig(
- timeline_limit=9,
- required_state_map={
- EventTypes.Dummy: {"dummy2"},
- StateValues.WILDCARD: {
- "",
- "@bar",
- },
- EventTypes.Member: {"@foo"},
- },
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Dummy: {
- "dummy",
- "dummy2",
- },
- StateValues.WILDCARD: {
- "",
- "@foo",
- "@bar",
- },
- },
- ),
- ),
- (
- "state_key_wildcard_overlap",
- # A
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Dummy: {"dummy"},
- EventTypes.Member: {StateValues.WILDCARD},
- "org.matrix.flowers": {StateValues.WILDCARD},
- },
- ),
- # B
- RoomSyncConfig(
- timeline_limit=9,
- required_state_map={
- EventTypes.Dummy: {StateValues.WILDCARD},
- EventTypes.Member: {StateValues.WILDCARD},
- "org.matrix.flowers": {"tulips"},
- },
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Dummy: {StateValues.WILDCARD},
- EventTypes.Member: {StateValues.WILDCARD},
- "org.matrix.flowers": {StateValues.WILDCARD},
- },
- ),
- ),
- (
- "state_type_and_state_key_wildcard_merge",
- # A
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Dummy: {"dummy"},
- StateValues.WILDCARD: {
- "",
- "@foo",
- },
- EventTypes.Member: {"@bar"},
- },
- ),
- # B
- RoomSyncConfig(
- timeline_limit=9,
- required_state_map={
- EventTypes.Dummy: {"dummy2"},
- StateValues.WILDCARD: {""},
- EventTypes.Member: {StateValues.WILDCARD},
- },
- ),
- # Expected
- RoomSyncConfig(
- timeline_limit=10,
- required_state_map={
- EventTypes.Dummy: {
- "dummy",
- "dummy2",
- },
- StateValues.WILDCARD: {
- "",
- "@foo",
- },
- EventTypes.Member: {StateValues.WILDCARD},
- },
- ),
- ),
- ]
- )
- def test_combine_room_sync_config(
- self,
- _test_label: str,
- a: RoomSyncConfig,
- b: RoomSyncConfig,
- expected: RoomSyncConfig,
- ) -> None:
- """
- Combine A into B and B into A to make sure we get the same result.
- """
- # Since we're mutating these in place, make a copy for each of our trials
- room_sync_config_a = deepcopy(a)
- room_sync_config_b = deepcopy(b)
-
- # Combine B into A
- room_sync_config_a.combine_room_sync_config(room_sync_config_b)
-
- self._assert_room_config_equal(room_sync_config_a, expected, "B into A")
-
- # Since we're mutating these in place, make a copy for each of our trials
- room_sync_config_a = deepcopy(a)
- room_sync_config_b = deepcopy(b)
-
- # Combine A into B
- room_sync_config_b.combine_room_sync_config(room_sync_config_a)
-
- self._assert_room_config_equal(room_sync_config_b, expected, "A into B")
-
-
-class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
- """
- Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it returns
- the correct list of rooms IDs.
- """
-
- servlets = [
- admin.register_servlets,
- knock.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
- self.storage_controllers = hs.get_storage_controllers()
-
- def test_no_rooms(self) -> None:
- """
- Test when the user has never joined any rooms before
- """
- user1_id = self.register_user("user1", "pass")
- # user1_tok = self.login(user1_id, "pass")
-
- now_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=now_token,
- to_token=now_token,
- )
- )
-
- self.assertEqual(room_id_results.keys(), set())
-
- def test_get_newly_joined_room(self) -> None:
- """
- Test that rooms that the user has newly_joined show up. newly_joined is when you
- join after the `from_token` and <= `to_token`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room_token = self.event_sources.get_current_token()
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response = self.helper.join(room_id, user1_id, tok=user1_tok)
-
- after_room_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room_token,
- to_token=after_room_token,
- )
- )
-
- self.assertEqual(room_id_results.keys(), {room_id})
- # It should be pointing to the join event (latest membership event in the
- # from/to range)
- self.assertEqual(
- room_id_results[room_id].event_id,
- join_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
- # We should be considered `newly_joined` because we joined during the token
- # range
- self.assertEqual(room_id_results[room_id].newly_joined, True)
- self.assertEqual(room_id_results[room_id].newly_left, False)
-
- def test_get_already_joined_room(self) -> None:
- """
- Test that rooms that the user is already joined show up.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response = self.helper.join(room_id, user1_id, tok=user1_tok)
-
- after_room_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room_token,
- to_token=after_room_token,
- )
- )
-
- self.assertEqual(room_id_results.keys(), {room_id})
- # It should be pointing to the join event (latest membership event in the
- # from/to range)
- self.assertEqual(
- room_id_results[room_id].event_id,
- join_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id].newly_joined, False)
- self.assertEqual(room_id_results[room_id].newly_left, False)
-
- def test_get_invited_banned_knocked_room(self) -> None:
- """
- Test that rooms that the user is invited to, banned from, and knocked on show
- up.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room_token = self.event_sources.get_current_token()
-
- # Setup the invited room (user2 invites user1 to the room)
- invited_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- invite_response = self.helper.invite(
- invited_room_id, targ=user1_id, tok=user2_tok
- )
-
- # Setup the ban room (user2 bans user1 from the room)
- ban_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- self.helper.join(ban_room_id, user1_id, tok=user1_tok)
- ban_response = self.helper.ban(
- ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok
- )
-
- # Setup the knock room (user1 knocks on the room)
- knock_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, room_version=RoomVersions.V7.identifier
- )
- self.helper.send_state(
- knock_room_id,
- EventTypes.JoinRules,
- {"join_rule": JoinRules.KNOCK},
- tok=user2_tok,
- )
- # User1 knocks on the room
- knock_channel = self.make_request(
- "POST",
- "/_matrix/client/r0/knock/%s" % (knock_room_id,),
- b"{}",
- user1_tok,
- )
- self.assertEqual(knock_channel.code, 200, knock_channel.result)
- knock_room_membership_state_event = self.get_success(
- self.storage_controllers.state.get_current_state_event(
- knock_room_id, EventTypes.Member, user1_id
- )
- )
- assert knock_room_membership_state_event is not None
-
- after_room_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room_token,
- to_token=after_room_token,
- )
- )
-
- # Ensure that the invited, ban, and knock rooms show up
- self.assertEqual(
- room_id_results.keys(),
- {
- invited_room_id,
- ban_room_id,
- knock_room_id,
- },
- )
- # It should be pointing to the the respective membership event (latest
- # membership event in the from/to range)
- self.assertEqual(
- room_id_results[invited_room_id].event_id,
- invite_response["event_id"],
- )
- self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
- self.assertEqual(room_id_results[invited_room_id].newly_left, False)
-
- self.assertEqual(
- room_id_results[ban_room_id].event_id,
- ban_response["event_id"],
- )
- self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
- self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
- self.assertEqual(room_id_results[ban_room_id].newly_left, False)
-
- self.assertEqual(
- room_id_results[knock_room_id].event_id,
- knock_room_membership_state_event.event_id,
- )
- self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
- self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
- self.assertEqual(room_id_results[knock_room_id].newly_left, False)
-
- def test_get_kicked_room(self) -> None:
- """
- Test that a room that the user was kicked from still shows up. When the user
- comes back to their client, they should see that they were kicked.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Setup the kick room (user2 kicks user1 from the room)
- kick_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- self.helper.join(kick_room_id, user1_id, tok=user1_tok)
- # Kick user1 from the room
- kick_response = self.helper.change_membership(
- room=kick_room_id,
- src=user2_id,
- targ=user1_id,
- tok=user2_tok,
- membership=Membership.LEAVE,
- extra_data={
- "reason": "Bad manners",
- },
- )
-
- after_kick_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_kick_token,
- to_token=after_kick_token,
- )
- )
-
- # The kicked room should show up
- self.assertEqual(room_id_results.keys(), {kick_room_id})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[kick_room_id].event_id,
- kick_response["event_id"],
- )
- self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
- self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
- # We should *NOT* be `newly_joined` because we were not joined at the the time
- # of the `to_token`.
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
-
- def test_forgotten_rooms(self) -> None:
- """
- Forgotten rooms do not show up even if we forget after the from/to range.
-
- Ideally, we would be able to track when the `/forget` happens and apply it
- accordingly in the token range but the forgotten flag is only an extra bool in
- the `room_memberships` table.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Setup a normal room that we leave. This won't show up in the sync response
- # because we left it before our token but is good to check anyway.
- leave_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- self.helper.join(leave_room_id, user1_id, tok=user1_tok)
- self.helper.leave(leave_room_id, user1_id, tok=user1_tok)
-
- # Setup the ban room (user2 bans user1 from the room)
- ban_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- self.helper.join(ban_room_id, user1_id, tok=user1_tok)
- self.helper.ban(ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
-
- # Setup the kick room (user2 kicks user1 from the room)
- kick_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- self.helper.join(kick_room_id, user1_id, tok=user1_tok)
- # Kick user1 from the room
- self.helper.change_membership(
- room=kick_room_id,
- src=user2_id,
- targ=user1_id,
- tok=user2_tok,
- membership=Membership.LEAVE,
- extra_data={
- "reason": "Bad manners",
- },
- )
-
- before_room_forgets = self.event_sources.get_current_token()
-
- # Forget the room after we already have our tokens. This doesn't change
- # the membership event itself but will mark it internally in Synapse
- channel = self.make_request(
- "POST",
- f"/_matrix/client/r0/rooms/{leave_room_id}/forget",
- content={},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.result)
- channel = self.make_request(
- "POST",
- f"/_matrix/client/r0/rooms/{ban_room_id}/forget",
- content={},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.result)
- channel = self.make_request(
- "POST",
- f"/_matrix/client/r0/rooms/{kick_room_id}/forget",
- content={},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.result)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room_forgets,
- to_token=before_room_forgets,
- )
- )
-
- # We shouldn't see the room because it was forgotten
- self.assertEqual(room_id_results.keys(), set())
-
- def test_newly_left_rooms(self) -> None:
- """
- Test that newly_left are marked properly
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Leave before we calculate the `from_token`
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Leave during the from_token/to_token range (newly_left)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
-
- after_room2_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room2_token,
- )
- )
-
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
-
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response1["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` or `newly_left` because that happened before
- # the from/to range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_response2["event_id"],
- )
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` because we are instead `newly_left`
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, True)
-
- def test_no_joins_after_to_token(self) -> None:
- """
- Rooms we join after the `to_token` should *not* show up. See condition "1b)"
- comments in the `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Room join after our `to_token` shouldn't show up
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_response1["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_join_during_range_and_left_room_after_to_token(self) -> None:
- """
- Room still shows up if we left the room but were joined during the
- from_token/to_token. See condition "1a)" comments in the
- `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Leave the room after we already have our tokens
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # We should still see the room because we were joined during the
- # from_token/to_token time period.
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response": join_response["event_id"],
- "leave_response": leave_response["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_join_before_range_and_left_room_after_to_token(self) -> None:
- """
- Room still shows up if we left the room but were joined before the `from_token`
- so it should show up. See condition "1a)" comments in the
- `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Leave the room after we already have our tokens
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # We should still see the room because we were joined before the `from_token`
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response": join_response["event_id"],
- "leave_response": leave_response["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_kicked_before_range_and_left_after_to_token(self) -> None:
- """
- Room still shows up if we left the room but were kicked before the `from_token`
- so it should show up. See condition "1a)" comments in the
- `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Setup the kick room (user2 kicks user1 from the room)
- kick_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- join_response1 = self.helper.join(kick_room_id, user1_id, tok=user1_tok)
- # Kick user1 from the room
- kick_response = self.helper.change_membership(
- room=kick_room_id,
- src=user2_id,
- targ=user1_id,
- tok=user2_tok,
- membership=Membership.LEAVE,
- extra_data={
- "reason": "Bad manners",
- },
- )
-
- after_kick_token = self.event_sources.get_current_token()
-
- # Leave the room after we already have our tokens
- #
- # We have to join before we can leave (leave -> leave isn't a valid transition
- # or at least it doesn't work in Synapse, 403 forbidden)
- join_response2 = self.helper.join(kick_room_id, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_kick_token,
- to_token=after_kick_token,
- )
- )
-
- # We shouldn't see the room because it was forgotten
- self.assertEqual(room_id_results.keys(), {kick_room_id})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[kick_room_id].event_id,
- kick_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response1": join_response1["event_id"],
- "kick_response": kick_response["event_id"],
- "join_response2": join_response2["event_id"],
- "leave_response": leave_response["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
- self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
- # We should *NOT* be `newly_joined` because we were kicked
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
-
- def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
- """
- Newly left room should show up. But we're also testing that joining and leaving
- after the `to_token` doesn't mess with the results. See condition "2)" and "1a)"
- comments in the `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join and leave the room during the from/to range
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Join and leave the room after we already have our tokens
- join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should still show up because it's newly_left during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response1["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response1": join_response1["event_id"],
- "leave_response1": leave_response1["event_id"],
- "join_response2": join_response2["event_id"],
- "leave_response2": leave_response2["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` because we are actually `newly_left` during
- # the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
- def test_newly_left_during_range_and_join_after_to_token(self) -> None:
- """
- Newly left room should show up. But we're also testing that joining after the
- `to_token` doesn't mess with the results. See condition "2)" and "1b)" comments
- in the `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join and leave the room during the from/to range
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Join the room after we already have our tokens
- join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should still show up because it's newly_left during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response1["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response1": join_response1["event_id"],
- "leave_response1": leave_response1["event_id"],
- "join_response2": join_response2["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` because we are actually `newly_left` during
- # the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
- def test_no_from_token(self) -> None:
- """
- Test that if we don't provide a `from_token`, we get all the rooms that we had
- membership in up to the `to_token`.
-
- Providing `from_token` only really has the effect that it marks rooms as
- `newly_left` in the response.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
-
- # Join room1
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Join and leave the room2 before the `to_token`
- self.helper.join(room_id2, user1_id, tok=user1_tok)
- leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Join the room2 after we already have our tokens
- self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_room1_token,
- )
- )
-
- # Only rooms we were joined to before the `to_token` should show up
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
-
- # Room1
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_response1["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined`/`newly_left` because there is no
- # `from_token` to define a "live" range to compare against
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_response2["event_id"],
- )
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because there is no
- # `from_token` to define a "live" range to compare against
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
-
- def test_from_token_ahead_of_to_token(self) -> None:
- """
- Test when the provided `from_token` comes after the `to_token`. We should
- basically expect the same result as having no `from_token`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
-
- # Join room1 before `to_token`
- join_room1_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Join and leave the room2 before `to_token`
- _join_room2_response1 = self.helper.join(room_id2, user1_id, tok=user1_tok)
- leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
-
- # Note: These are purposely swapped. The `from_token` should come after
- # the `to_token` in this test
- to_token = self.event_sources.get_current_token()
-
- # Join room2 after `to_token`
- _join_room2_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- # --------
-
- # Join room3 after `to_token`
- _join_room3_response1 = self.helper.join(room_id3, user1_id, tok=user1_tok)
-
- # Join and leave the room4 after `to_token`
- _join_room4_response1 = self.helper.join(room_id4, user1_id, tok=user1_tok)
- _leave_room4_response1 = self.helper.leave(room_id4, user1_id, tok=user1_tok)
-
- # Note: These are purposely swapped. The `from_token` should come after the
- # `to_token` in this test
- from_token = self.event_sources.get_current_token()
-
- # Join the room4 after we already have our tokens
- self.helper.join(room_id4, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=from_token,
- to_token=to_token,
- )
- )
-
- # In the "current" state snapshot, we're joined to all of the rooms but in the
- # from/to token range...
- self.assertIncludes(
- room_id_results.keys(),
- {
- # Included because we were joined before both tokens
- room_id1,
- # Included because we had membership before the to_token
- room_id2,
- # Excluded because we joined after the `to_token`
- # room_id3,
- # Excluded because we joined after the `to_token`
- # room_id4,
- },
- exact=True,
- )
-
- # Room1
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_room1_response1["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1`
- # before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_room2_response1["event_id"],
- )
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
-
- def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
- """
- Test old left rooms. But we're also testing that joining and leaving after the
- `to_token` doesn't mess with the results. See condition "1a)" comments in the
- `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join and leave the room before the from/to range
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Join and leave the room after we already have our tokens
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room1_token,
- )
- )
-
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_leave_before_range_and_join_after_to_token(self) -> None:
- """
- Test old left room. But we're also testing that joining after the `to_token`
- doesn't mess with the results. See condition "1b)" comments in the
- `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join and leave the room before the from/to range
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Join the room after we already have our tokens
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room1_token,
- )
- )
-
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_join_leave_multiple_times_during_range_and_after_to_token(
- self,
- ) -> None:
- """
- Join and leave multiple times shouldn't affect rooms from showing up. It just
- matters that we had membership in the from/to range. But we're also testing that
- joining and leaving after the `to_token` doesn't mess with the results.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join, leave, join back to the room during the from/to range
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Leave and Join the room multiple times after we already have our tokens
- leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should show up because it was newly_left and joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_response2["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response1": join_response1["event_id"],
- "leave_response1": leave_response1["event_id"],
- "join_response2": join_response2["event_id"],
- "leave_response2": leave_response2["event_id"],
- "join_response3": join_response3["event_id"],
- "leave_response3": leave_response3["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- # We should *NOT* be `newly_left` because we joined during the token range and
- # was still joined at the end of the range
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_join_leave_multiple_times_before_range_and_after_to_token(
- self,
- ) -> None:
- """
- Join and leave multiple times before the from/to range shouldn't affect rooms
- from showing up. It just matters that we had membership in the
- from/to range. But we're also testing that joining and leaving after the
- `to_token` doesn't mess with the results.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join, leave, join back to the room before the from/to range
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Leave and Join the room multiple times after we already have our tokens
- leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should show up because we were joined before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_response2["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response1": join_response1["event_id"],
- "leave_response1": leave_response1["event_id"],
- "join_response2": join_response2["event_id"],
- "leave_response2": leave_response2["event_id"],
- "join_response3": join_response3["event_id"],
- "leave_response3": leave_response3["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_invite_before_range_and_join_leave_after_to_token(
- self,
- ) -> None:
- """
- Make it look like we joined after the token range but we were invited before the
- from/to range so the room should still show up. See condition "1a)" comments in
- the `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
-
- # Invited to the room before the token
- invite_response = self.helper.invite(
- room_id1, src=user2_id, targ=user1_id, tok=user2_tok
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Join and leave the room after we already have our tokens
- join_respsonse = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should show up because we were invited before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- invite_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "invite_response": invite_response["event_id"],
- "join_respsonse": join_respsonse["event_id"],
- "leave_response": leave_response["event_id"],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.INVITE)
- # We should *NOT* be `newly_joined` because we were only invited before the
- # token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_join_and_display_name_changes_in_token_range(
- self,
- ) -> None:
- """
- Test that we point to the correct membership event within the from/to range even
- if there are multiple `join` membership events in a row indicating
- `displayname`/`avatar_url` updates.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
- # Update the displayname during the token range
- displayname_change_during_token_range_response = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname during token range",
- },
- tok=user1_tok,
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Update the displayname after the token range
- displayname_change_after_token_range_response = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname after token range",
- },
- tok=user1_tok,
- )
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- displayname_change_during_token_range_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response": join_response["event_id"],
- "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
- "event_id"
- ],
- "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
- "event_id"
- ],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_display_name_changes_in_token_range(
- self,
- ) -> None:
- """
- Test that we point to the correct membership event within the from/to range even
- if there is `displayname`/`avatar_url` updates.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Update the displayname during the token range
- displayname_change_during_token_range_response = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname during token range",
- },
- tok=user1_tok,
- )
-
- after_change1_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_change1_token,
- )
- )
-
- # Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- displayname_change_during_token_range_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response": join_response["event_id"],
- "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
- "event_id"
- ],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_display_name_changes_before_and_after_token_range(
- self,
- ) -> None:
- """
- Test that we point to the correct membership event even though there are no
- membership events in the from/range but there are `displayname`/`avatar_url`
- changes before/after the token range.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
- # Update the displayname before the token range
- displayname_change_before_token_range_response = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname during token range",
- },
- tok=user1_tok,
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Update the displayname after the token range
- displayname_change_after_token_range_response = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname after token range",
- },
- tok=user1_tok,
- )
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should show up because we were joined before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- displayname_change_before_token_range_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response": join_response["event_id"],
- "displayname_change_before_token_range_response": displayname_change_before_token_range_response[
- "event_id"
- ],
- "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
- "event_id"
- ],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_display_name_changes_leave_after_token_range(
- self,
- ) -> None:
- """
- Test that we point to the correct membership event within the from/to range even
- if there are multiple `join` membership events in a row indicating
- `displayname`/`avatar_url` updates and we leave after the `to_token`.
-
- See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
- # Update the displayname during the token range
- displayname_change_during_token_range_response = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname during token range",
- },
- tok=user1_tok,
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Update the displayname after the token range
- displayname_change_after_token_range_response = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname after token range",
- },
- tok=user1_tok,
- )
-
- # Leave after the token
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- displayname_change_during_token_range_response["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response": join_response["event_id"],
- "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
- "event_id"
- ],
- "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
- "event_id"
- ],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_display_name_changes_join_after_token_range(
- self,
- ) -> None:
- """
- Test that multiple `join` membership events (after the `to_token`) in a row
- indicating `displayname`/`avatar_url` updates doesn't affect the results (we
- joined after the token range so it shouldn't show up)
-
- See condition "1b)" comments in the `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
-
- after_room1_token = self.event_sources.get_current_token()
-
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- # Update the displayname after the token range
- self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname after token range",
- },
- tok=user1_tok,
- )
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room shouldn't show up because we joined after the from/to range
- self.assertEqual(room_id_results.keys(), set())
-
- def test_newly_joined_with_leave_join_in_token_range(
- self,
- ) -> None:
- """
- Test that even though we're joined before the token range, if we leave and join
- within the token range, it's still counted as `newly_joined`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Leave and join back during the token range
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- after_more_changes_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_more_changes_token,
- )
- )
-
- # Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_response2["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be considered `newly_joined` because there is some non-join event in
- # between our latest join event.
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_newly_joined_only_joins_during_token_range(
- self,
- ) -> None:
- """
- Test that a join and more joins caused by display name changes, all during the
- token range, still count as `newly_joined`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # Join, leave, join back to the room before the from/to range
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- # Update the displayname during the token range (looks like another join)
- displayname_change_during_token_range_response1 = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname during token range",
- },
- tok=user1_tok,
- )
- # Update the displayname during the token range (looks like another join)
- displayname_change_during_token_range_response2 = self.helper.send_state(
- room_id1,
- event_type=EventTypes.Member,
- state_key=user1_id,
- body={
- "membership": Membership.JOIN,
- "displayname": "displayname during token range",
- },
- tok=user1_tok,
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room1_token,
- to_token=after_room1_token,
- )
- )
-
- # Room should show up because it was newly_left and joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- displayname_change_during_token_range_response2["event_id"],
- "Corresponding map to disambiguate the opaque event IDs: "
- + str(
- {
- "join_response1": join_response1["event_id"],
- "displayname_change_during_token_range_response1": displayname_change_during_token_range_response1[
- "event_id"
- ],
- "displayname_change_during_token_range_response2": displayname_change_during_token_range_response2[
- "event_id"
- ],
- }
- ),
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we first joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- def test_multiple_rooms_are_not_confused(
- self,
- ) -> None:
- """
- Test that multiple rooms are not confused as we fixup the list. This test is
- spawning from a real world bug in the code where I was accidentally using
- `event.room_id` in one of the fix-up loops but the `event` being referenced was
- actually from a different loop.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # We create the room with user2 so the room isn't left with no members when we
- # leave and can still re-join.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
-
- # Invited and left the room before the token
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- # Invited to room2
- invite_room2_response = self.helper.invite(
- room_id2, src=user2_id, targ=user1_id, tok=user2_tok
- )
-
- before_room3_token = self.event_sources.get_current_token()
-
- # Invited and left room3 during the from/to range
- room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- self.helper.invite(room_id3, src=user2_id, targ=user1_id, tok=user2_tok)
- leave_room3_response = self.helper.leave(room_id3, user1_id, tok=user1_tok)
-
- after_room3_token = self.event_sources.get_current_token()
-
- # Join and leave the room after we already have our tokens
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- # Leave room2
- self.helper.leave(room_id2, user1_id, tok=user1_tok)
- # Leave room3
- self.helper.leave(room_id3, user1_id, tok=user1_tok)
-
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_room3_token,
- to_token=after_room3_token,
- )
- )
-
- self.assertEqual(
- room_id_results.keys(),
- {
- # Left before the from/to range
- room_id1,
- # Invited before the from/to range
- room_id2,
- # `newly_left` during the from/to range
- room_id3,
- },
- )
-
- # Room1
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_room1_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we were invited and left
- # before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id2].event_id,
- invite_room2_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id2].membership, Membership.INVITE)
- # We should *NOT* be `newly_joined`/`newly_left` because we were invited before
- # the token range
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
-
- # Room3
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id3].event_id,
- leave_room3_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id3].membership, Membership.LEAVE)
- # We should be `newly_left` because we were invited and left during
- # the token range
- self.assertEqual(room_id_results[room_id3].newly_joined, False)
- self.assertEqual(room_id_results[room_id3].newly_left, True)
-
- def test_state_reset(self) -> None:
- """
- Test a state reset scenario where the user gets removed from the room (when
- there is no corresponding leave event)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # The room where the state reset will happen
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Join another room so we don't hit the short-circuit and return early if they
- # have no room membership
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- before_reset_token = self.event_sources.get_current_token()
-
- # Send another state event to make a position for the state reset to happen at
- dummy_state_response = self.helper.send_state(
- room_id1,
- event_type="foobarbaz",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- dummy_state_pos = self.get_success(
- self.store.get_position_for_event(dummy_state_response["event_id"])
- )
-
- # Mock a state reset removing the membership for user1 in the current state
- self.get_success(
- self.store.db_pool.simple_delete(
- table="current_state_events",
- keyvalues={
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- },
- desc="state reset user in current_state_events",
- )
- )
- self.get_success(
- self.store.db_pool.simple_delete(
- table="local_current_membership",
- keyvalues={
- "room_id": room_id1,
- "user_id": user1_id,
- },
- desc="state reset user in local_current_membership",
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- table="current_state_delta_stream",
- values={
- "stream_id": dummy_state_pos.stream,
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- "event_id": None,
- "prev_event_id": join_response1["event_id"],
- "instance_name": dummy_state_pos.instance_name,
- },
- desc="state reset user in current_state_delta_stream",
- )
- )
-
- # Manually bust the cache since we we're just manually messing with the database
- # and not causing an actual state reset.
- self.store._membership_stream_cache.entity_has_changed(
- user1_id, dummy_state_pos.stream
- )
-
- after_reset_token = self.event_sources.get_current_token()
-
- # The function under test
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_reset_token,
- to_token=after_reset_token,
- )
- )
-
- # Room1 should show up because it was `newly_left` via state reset during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
- # It should be pointing to no event because we were removed from the room
- # without a corresponding leave event
- self.assertEqual(
- room_id_results[room_id1].event_id,
- None,
- )
- # State reset caused us to leave the room and there is no corresponding leave event
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- # We should be `newly_left` because we were removed via state reset during the from/to range
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
-
-class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCase):
- """
- Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it works with
- sharded event stream_writers enabled
- """
-
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def default_config(self) -> dict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
-
- # Enable shared event stream_writers
- config["stream_writers"] = {"events": ["worker1", "worker2", "worker3"]}
- config["instance_map"] = {
- "main": {"host": "testserv", "port": 8765},
- "worker1": {"host": "testserv", "port": 1001},
- "worker2": {"host": "testserv", "port": 1002},
- "worker3": {"host": "testserv", "port": 1003},
- }
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
-
- def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
- """
- Create a room with a specific room_id. We use this so that that we have a
- consistent room_id across test runs that hashes to the same value and will be
- sharded to a known worker in the tests.
- """
-
- # We control the room ID generation by patching out the
- # `_generate_room_id` method
- with patch(
- "synapse.handlers.room.RoomCreationHandler._generate_room_id"
- ) as mock:
- mock.side_effect = lambda: room_id
- self.helper.create_room_as(user_id, tok=tok)
-
- def test_sharded_event_persisters(self) -> None:
- """
- This test should catch bugs that would come from flawed stream position
- (`stream_ordering`) comparisons or making `RoomStreamToken`'s naively. To
- compare event positions properly, you need to consider both the `instance_name`
- and `stream_ordering` together.
-
- The test creates three event persister workers and a room that is sharded to
- each worker. On worker2, we make the event stream position stuck so that it lags
- behind the other workers and we start getting `RoomStreamToken` that have an
- `instance_map` component (i.e. q`m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`).
-
- We then send some events to advance the stream positions of worker1 and worker3
- but worker2 is lagging behind because it's stuck. We are specifically testing
- that `get_room_membership_for_user_at_to_token(from_token=xxx, to_token=xxx)` should work
- correctly in these adverse conditions.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- self.make_worker_hs(
- "synapse.app.generic_worker",
- {"worker_name": "worker1"},
- )
-
- worker_hs2 = self.make_worker_hs(
- "synapse.app.generic_worker",
- {"worker_name": "worker2"},
- )
-
- self.make_worker_hs(
- "synapse.app.generic_worker",
- {"worker_name": "worker3"},
- )
-
- # Specially crafted room IDs that get persisted on different workers.
- #
- # Sharded to worker1
- room_id1 = "!fooo:test"
- # Sharded to worker2
- room_id2 = "!bar:test"
- # Sharded to worker3
- room_id3 = "!quux:test"
-
- # Create rooms on the different workers.
- self._create_room(room_id1, user2_id, user2_tok)
- self._create_room(room_id2, user2_id, user2_tok)
- self._create_room(room_id3, user2_id, user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
- # Leave room2
- leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok)
- join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok)
- # Leave room3
- self.helper.leave(room_id3, user1_id, tok=user1_tok)
-
- # Ensure that the events were sharded to different workers.
- pos1 = self.get_success(
- self.store.get_position_for_event(join_response1["event_id"])
- )
- self.assertEqual(pos1.instance_name, "worker1")
- pos2 = self.get_success(
- self.store.get_position_for_event(join_response2["event_id"])
- )
- self.assertEqual(pos2.instance_name, "worker2")
- pos3 = self.get_success(
- self.store.get_position_for_event(join_response3["event_id"])
- )
- self.assertEqual(pos3.instance_name, "worker3")
-
- before_stuck_activity_token = self.event_sources.get_current_token()
-
- # We now gut wrench into the events stream `MultiWriterIdGenerator` on worker2 to
- # mimic it getting stuck persisting an event. This ensures that when we send an
- # event on worker1/worker3 we end up in a state where worker2 events stream
- # position lags that on worker1/worker3, resulting in a RoomStreamToken with a
- # non-empty `instance_map` component.
- #
- # Worker2's event stream position will not advance until we call `__aexit__`
- # again.
- worker_store2 = worker_hs2.get_datastores().main
- assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
- actx = worker_store2._stream_id_gen.get_next()
- self.get_success(actx.__aenter__())
-
- # For room_id1/worker1: leave and join the room to advance the stream position
- # and generate membership changes.
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- join_room1_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
- # For room_id2/worker2: which is currently stuck, join the room.
- join_on_worker2_response = self.helper.join(room_id2, user1_id, tok=user1_tok)
- # For room_id3/worker3: leave and join the room to advance the stream position
- # and generate membership changes.
- self.helper.leave(room_id3, user1_id, tok=user1_tok)
- join_on_worker3_response = self.helper.join(room_id3, user1_id, tok=user1_tok)
-
- # Get a token while things are stuck after our activity
- stuck_activity_token = self.event_sources.get_current_token()
- # Let's make sure we're working with a token that has an `instance_map`
- self.assertNotEqual(len(stuck_activity_token.room_key.instance_map), 0)
-
- # Just double check that the join event on worker2 (that is stuck) happened
- # after the position recorded for worker2 in the token but before the max
- # position in the token. This is crucial for the behavior we're trying to test.
- join_on_worker2_pos = self.get_success(
- self.store.get_position_for_event(join_on_worker2_response["event_id"])
- )
- # Ensure the join technially came after our token
- self.assertGreater(
- join_on_worker2_pos.stream,
- stuck_activity_token.room_key.get_stream_pos_for_instance("worker2"),
- )
- # But less than the max stream position of some other worker
- self.assertLess(
- join_on_worker2_pos.stream,
- # max
- stuck_activity_token.room_key.get_max_stream_pos(),
- )
-
- # Just double check that the join event on worker3 happened after the min stream
- # value in the token but still within the position recorded for worker3. This is
- # crucial for the behavior we're trying to test.
- join_on_worker3_pos = self.get_success(
- self.store.get_position_for_event(join_on_worker3_response["event_id"])
- )
- # Ensure the join came after the min but still encapsulated by the token
- self.assertGreaterEqual(
- join_on_worker3_pos.stream,
- # min
- stuck_activity_token.room_key.stream,
- )
- self.assertLessEqual(
- join_on_worker3_pos.stream,
- stuck_activity_token.room_key.get_stream_pos_for_instance("worker3"),
- )
-
- # We finish the fake persisting an event we started above and advance worker2's
- # event stream position (unstuck worker2).
- self.get_success(actx.__aexit__(None, None, None))
-
- # The function under test
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
- from_token=before_stuck_activity_token,
- to_token=stuck_activity_token,
- )
- )
-
- self.assertEqual(
- room_id_results.keys(),
- {
- room_id1,
- room_id2,
- room_id3,
- },
- )
-
- # Room1
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- join_room1_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_room2_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # room_id2 should *NOT* be considered `newly_left` because we left before the
- # from/to range and the join event during the range happened while worker2 was
- # stuck. This means that from the perspective of the master, where the
- # `stuck_activity_token` is generated, the stream position for worker2 wasn't
- # advanced to the join yet. Looking at the `instance_map`, the join technically
- # comes after `stuck_activity_token`.
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
-
- # Room3
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id3].event_id,
- join_on_worker3_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id3].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id3].newly_joined, True)
- self.assertEqual(room_id_results[room_id3].newly_left, False)
-
-
-class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
- """
- Tests Sliding Sync handler `filter_rooms_relevant_for_sync()` to make sure it returns
- the correct list of rooms IDs.
- """
-
- servlets = [
- admin.register_servlets,
- knock.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
- self.storage_controllers = hs.get_storage_controllers()
-
- def _get_sync_room_ids_for_user(
- self,
- user: UserID,
- to_token: StreamToken,
- from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Get the rooms the user should be syncing with
- """
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- user=user,
- from_token=from_token,
- to_token=to_token,
- )
- )
- filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
- user=user,
- room_membership_for_user_map=room_membership_for_user_map,
- )
- )
-
- return filtered_sync_room_map
-
- def test_no_rooms(self) -> None:
- """
- Test when the user has never joined any rooms before
- """
- user1_id = self.register_user("user1", "pass")
- # user1_tok = self.login(user1_id, "pass")
-
- now_token = self.event_sources.get_current_token()
-
- room_id_results = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=now_token,
- to_token=now_token,
- )
-
- self.assertEqual(room_id_results.keys(), set())
-
- def test_basic_rooms(self) -> None:
- """
- Test that rooms that the user is joined to, invited to, banned from, and knocked
- on show up.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room_token = self.event_sources.get_current_token()
-
- join_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response = self.helper.join(join_room_id, user1_id, tok=user1_tok)
-
- # Setup the invited room (user2 invites user1 to the room)
- invited_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- invite_response = self.helper.invite(
- invited_room_id, targ=user1_id, tok=user2_tok
- )
-
- # Setup the ban room (user2 bans user1 from the room)
- ban_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- self.helper.join(ban_room_id, user1_id, tok=user1_tok)
- ban_response = self.helper.ban(
- ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok
- )
-
- # Setup the knock room (user1 knocks on the room)
- knock_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, room_version=RoomVersions.V7.identifier
- )
- self.helper.send_state(
- knock_room_id,
- EventTypes.JoinRules,
- {"join_rule": JoinRules.KNOCK},
- tok=user2_tok,
- )
- # User1 knocks on the room
- knock_channel = self.make_request(
- "POST",
- "/_matrix/client/r0/knock/%s" % (knock_room_id,),
- b"{}",
- user1_tok,
- )
- self.assertEqual(knock_channel.code, 200, knock_channel.result)
- knock_room_membership_state_event = self.get_success(
- self.storage_controllers.state.get_current_state_event(
- knock_room_id, EventTypes.Member, user1_id
- )
- )
- assert knock_room_membership_state_event is not None
-
- after_room_token = self.event_sources.get_current_token()
-
- room_id_results = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=before_room_token,
- to_token=after_room_token,
- )
-
- # Ensure that the invited, ban, and knock rooms show up
- self.assertEqual(
- room_id_results.keys(),
- {
- join_room_id,
- invited_room_id,
- ban_room_id,
- knock_room_id,
- },
- )
- # It should be pointing to the the respective membership event (latest
- # membership event in the from/to range)
- self.assertEqual(
- room_id_results[join_room_id].event_id,
- join_response["event_id"],
- )
- self.assertEqual(room_id_results[join_room_id].membership, Membership.JOIN)
- self.assertEqual(room_id_results[join_room_id].newly_joined, True)
- self.assertEqual(room_id_results[join_room_id].newly_left, False)
-
- self.assertEqual(
- room_id_results[invited_room_id].event_id,
- invite_response["event_id"],
- )
- self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
- self.assertEqual(room_id_results[invited_room_id].newly_left, False)
-
- self.assertEqual(
- room_id_results[ban_room_id].event_id,
- ban_response["event_id"],
- )
- self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
- self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
- self.assertEqual(room_id_results[ban_room_id].newly_left, False)
-
- self.assertEqual(
- room_id_results[knock_room_id].event_id,
- knock_room_membership_state_event.event_id,
- )
- self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
- self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
- self.assertEqual(room_id_results[knock_room_id].newly_left, False)
-
- def test_only_newly_left_rooms_show_up(self) -> None:
- """
- Test that `newly_left` rooms still show up in the sync response but rooms that
- were left before the `from_token` don't show up. See condition "2)" comments in
- the `get_room_membership_for_user_at_to_token()` method.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Leave before we calculate the `from_token`
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Leave during the from_token/to_token range (newly_left)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
-
- after_room2_token = self.event_sources.get_current_token()
-
- room_id_results = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=after_room1_token,
- to_token=after_room2_token,
- )
-
- # Only the `newly_left` room should show up
- self.assertEqual(room_id_results.keys(), {room_id2})
- self.assertEqual(
- room_id_results[room_id2].event_id,
- _leave_response2["event_id"],
- )
- # We should *NOT* be `newly_joined` because we are instead `newly_left`
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, True)
-
- def test_get_kicked_room(self) -> None:
- """
- Test that a room that the user was kicked from still shows up. When the user
- comes back to their client, they should see that they were kicked.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Setup the kick room (user2 kicks user1 from the room)
- kick_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok, is_public=True
- )
- self.helper.join(kick_room_id, user1_id, tok=user1_tok)
- # Kick user1 from the room
- kick_response = self.helper.change_membership(
- room=kick_room_id,
- src=user2_id,
- targ=user1_id,
- tok=user2_tok,
- membership=Membership.LEAVE,
- extra_data={
- "reason": "Bad manners",
- },
- )
-
- after_kick_token = self.event_sources.get_current_token()
-
- room_id_results = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=after_kick_token,
- to_token=after_kick_token,
- )
-
- # The kicked room should show up
- self.assertEqual(room_id_results.keys(), {kick_room_id})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[kick_room_id].event_id,
- kick_response["event_id"],
- )
- self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
- self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
- # We should *NOT* be `newly_joined` because we were not joined at the the time
- # of the `to_token`.
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
-
- def test_state_reset(self) -> None:
- """
- Test a state reset scenario where the user gets removed from the room (when
- there is no corresponding leave event)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # The room where the state reset will happen
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Join another room so we don't hit the short-circuit and return early if they
- # have no room membership
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- before_reset_token = self.event_sources.get_current_token()
-
- # Send another state event to make a position for the state reset to happen at
- dummy_state_response = self.helper.send_state(
- room_id1,
- event_type="foobarbaz",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- dummy_state_pos = self.get_success(
- self.store.get_position_for_event(dummy_state_response["event_id"])
- )
-
- # Mock a state reset removing the membership for user1 in the current state
- self.get_success(
- self.store.db_pool.simple_delete(
- table="current_state_events",
- keyvalues={
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- },
- desc="state reset user in current_state_events",
- )
- )
- self.get_success(
- self.store.db_pool.simple_delete(
- table="local_current_membership",
- keyvalues={
- "room_id": room_id1,
- "user_id": user1_id,
- },
- desc="state reset user in local_current_membership",
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- table="current_state_delta_stream",
- values={
- "stream_id": dummy_state_pos.stream,
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- "event_id": None,
- "prev_event_id": join_response1["event_id"],
- "instance_name": dummy_state_pos.instance_name,
- },
- desc="state reset user in current_state_delta_stream",
- )
- )
-
- # Manually bust the cache since we we're just manually messing with the database
- # and not causing an actual state reset.
- self.store._membership_stream_cache.entity_has_changed(
- user1_id, dummy_state_pos.stream
- )
-
- after_reset_token = self.event_sources.get_current_token()
-
- # The function under test
- room_id_results = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=before_reset_token,
- to_token=after_reset_token,
- )
-
- # Room1 should show up because it was `newly_left` via state reset during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
- # It should be pointing to no event because we were removed from the room
- # without a corresponding leave event
- self.assertEqual(
- room_id_results[room_id1].event_id,
- None,
- )
- # State reset caused us to leave the room and there is no corresponding leave event
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- # We should be `newly_left` because we were removed via state reset during the from/to range
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
-
-class FilterRoomsTestCase(HomeserverTestCase):
- """
- Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms
- correctly.
- """
-
- servlets = [
- admin.register_servlets,
- knock.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
-
- def _get_sync_room_ids_for_user(
- self,
- user: UserID,
- to_token: StreamToken,
- from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Get the rooms the user should be syncing with
- """
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- user=user,
- from_token=from_token,
- to_token=to_token,
- )
- )
- filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
- user=user,
- room_membership_for_user_map=room_membership_for_user_map,
- )
- )
-
- return filtered_sync_room_map
-
- def _create_dm_room(
- self,
- inviter_user_id: str,
- inviter_tok: str,
- invitee_user_id: str,
- invitee_tok: str,
- ) -> str:
- """
- Helper to create a DM room as the "inviter" and invite the "invitee" user to the room. The
- "invitee" user also will join the room. The `m.direct` account data will be set
- for both users.
- """
-
- # Create a room and send an invite the other user
- room_id = self.helper.create_room_as(
- inviter_user_id,
- is_public=False,
- tok=inviter_tok,
- )
- self.helper.invite(
- room_id,
- src=inviter_user_id,
- targ=invitee_user_id,
- tok=inviter_tok,
- extra_data={"is_direct": True},
- )
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
-
- # Mimic the client setting the room as a direct message in the global account
- # data
- self.get_success(
- self.store.add_account_data_for_user(
- invitee_user_id,
- AccountDataTypes.DIRECT,
- {inviter_user_id: [room_id]},
- )
- )
- self.get_success(
- self.store.add_account_data_for_user(
- inviter_user_id,
- AccountDataTypes.DIRECT,
- {invitee_user_id: [room_id]},
- )
- )
-
- return room_id
-
- _remote_invite_count: int = 0
-
- def _create_remote_invite_room_for_user(
- self,
- invitee_user_id: str,
- unsigned_invite_room_state: Optional[List[StrippedStateEvent]],
- ) -> str:
- """
- Create a fake invite for a remote room and persist it.
-
- We don't have any state for these kind of rooms and can only rely on the
- stripped state included in the unsigned portion of the invite event to identify
- the room.
-
- Args:
- invitee_user_id: The person being invited
- unsigned_invite_room_state: List of stripped state events to assist the
- receiver in identifying the room.
-
- Returns:
- The room ID of the remote invite room
- """
- invite_room_id = f"!test_room{self._remote_invite_count}:remote_server"
-
- invite_event_dict = {
- "room_id": invite_room_id,
- "sender": "@inviter:remote_server",
- "state_key": invitee_user_id,
- "depth": 1,
- "origin_server_ts": 1,
- "type": EventTypes.Member,
- "content": {"membership": Membership.INVITE},
- "auth_events": [],
- "prev_events": [],
- }
- if unsigned_invite_room_state is not None:
- serialized_stripped_state_events = []
- for stripped_event in unsigned_invite_room_state:
- serialized_stripped_state_events.append(
- {
- "type": stripped_event.type,
- "state_key": stripped_event.state_key,
- "sender": stripped_event.sender,
- "content": stripped_event.content,
- }
- )
-
- invite_event_dict["unsigned"] = {
- "invite_room_state": serialized_stripped_state_events
- }
-
- invite_event = make_event_from_dict(
- invite_event_dict,
- room_version=RoomVersions.V10,
- )
- invite_event.internal_metadata.outlier = True
- invite_event.internal_metadata.out_of_band_membership = True
-
- self.get_success(
- self.store.maybe_store_room_on_outlier_membership(
- room_id=invite_room_id, room_version=invite_event.room_version
- )
- )
- context = EventContext.for_outlier(self.hs.get_storage_controllers())
- persist_controller = self.hs.get_storage_controllers().persistence
- assert persist_controller is not None
- self.get_success(persist_controller.persist_event(invite_event, context))
-
- self._remote_invite_count += 1
-
- return invite_room_id
-
- def test_filter_dm_rooms(self) -> None:
- """
- Test `filter.is_dm` for DM rooms
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a DM room
- dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_dm=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_dm=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {dm_room_id})
-
- # Try with `is_dm=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_dm=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_rooms(self) -> None:
- """
- Test `filter.is_encrypted` for encrypted rooms
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_server_left_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that everyone has left.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Leave the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
- # Leave the room
- self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_server_left_room2(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that everyone has
- left.
-
- There is still someone local who is invited to the rooms but that doesn't affect
- whether the server is participating in the room (users need to be joined).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- _user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Invite user2
- self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
- # Invite user2
- self.helper.invite(encrypted_room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_after_we_left(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that was encrypted
- after we left the room (make sure we don't just use the current state)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- # Leave the room
- self.helper.join(room_id, user1_id, tok=user1_tok)
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a room that will be encrypted
- encrypted_after_we_left_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok
- )
- # Leave the room
- self.helper.join(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
- self.helper.leave(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
-
- # Encrypt the room after we've left
- self.helper.send_state(
- encrypted_after_we_left_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user2_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # Even though we left the room before it was encrypted, we still see it because
- # someone else on our server is still participating in the room and we "leak"
- # the current state to the left user. But we consider the room encryption status
- # to not be a secret given it's often set at the start of the room and it's one
- # of the stripped state events that is normally handed out.
- self.assertEqual(
- truthy_filtered_room_map.keys(), {encrypted_after_we_left_room_id}
- )
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # Even though we left the room before it was encrypted... (see comment above)
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_room_no_stripped_state(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- room without any `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room without any `unsigned.invite_room_state`
- _remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id, None
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out whether
- # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out whether
- # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_encrypted_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- encrypted room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # indicating that the room is encrypted.
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- },
- ),
- StrippedStateEvent(
- type=EventTypes.RoomEncryption,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
- },
- ),
- ],
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is encrypted
- # according to the stripped state
- self.assertEqual(
- truthy_filtered_room_map.keys(), {encrypted_room_id, remote_invite_room_id}
- )
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is encrypted
- # according to the stripped state
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_unencrypted_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- unencrypted room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # but don't set any room encryption event.
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- },
- ),
- # No room encryption event
- ],
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is unencrypted
- # according to the stripped state
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear because it is unencrypted according to
- # the stripped state
- self.assertEqual(
- falsy_filtered_room_map.keys(), {room_id, remote_invite_room_id}
- )
-
- def test_filter_invite_rooms(self) -> None:
- """
- Test `filter.is_invite` for rooms that the user has been invited to
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a normal room
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- # Create a room that user1 is invited to
- invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_invite=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_invite=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id})
-
- # Try with `is_invite=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_invite=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_room_types(self) -> None:
- """
- Test `filter.room_types` for different room types
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- # Create an arbitrarily typed room
- foo_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {
- EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
- }
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- # Try finding normal rooms and spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None, RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id, space_room_id})
-
- # Try finding an arbitrary room type
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=["org.matrix.foobarbaz"]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {foo_room_id})
-
- def test_filter_not_room_types(self) -> None:
- """
- Test `filter.not_room_types` for different room types
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- # Create an arbitrarily typed room
- foo_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {
- EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
- }
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding *NOT* normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(not_room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id, foo_room_id})
-
- # Try finding *NOT* spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- not_room_types=[RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id, foo_room_id})
-
- # Try finding *NOT* normal rooms or spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- not_room_types=[None, RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {foo_room_id})
-
- # Test how it behaves when we have both `room_types` and `not_room_types`.
- # `not_room_types` should win.
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None], not_room_types=[None]
- ),
- after_rooms_token,
- )
- )
-
- # Nothing matches because nothing is both a normal room and not a normal room
- self.assertEqual(filtered_room_map.keys(), set())
-
- # Test how it behaves when we have both `room_types` and `not_room_types`.
- # `not_room_types` should win.
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None, RoomTypes.SPACE], not_room_types=[None]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_server_left_room(self) -> None:
- """
- Test that we can apply a `filter.room_types` against a room that everyone has left.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Leave the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- # Leave the room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_server_left_room2(self) -> None:
- """
- Test that we can apply a `filter.room_types` against a room that everyone has left.
-
- There is still someone local who is invited to the rooms but that doesn't affect
- whether the server is participating in the room (users need to be joined).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- _user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Invite user2
- self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- # Invite user2
- self.helper.invite(space_room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_with_remote_invite_room_no_stripped_state(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- room without any `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room without any `unsigned.invite_room_state`
- _remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id, None
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out what
- # room type it is (no stripped state, `unsigned.invite_room_state`)
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out what
- # room type it is (no stripped state, `unsigned.invite_room_state`)
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_with_remote_invite_space(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- to a space room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state` indicating
- # that it is a space room
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- # Specify that it is a space room
- EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
- },
- ),
- ],
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is a space room
- # according to the stripped state
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is a space room
- # according to the stripped state
- self.assertEqual(
- filtered_room_map.keys(), {space_room_id, remote_invite_room_id}
- )
-
- def test_filter_room_types_with_remote_invite_normal_room(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- to a normal room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # but the create event does not specify a room type (normal room)
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- # No room type means this is a normal room
- },
- ),
- ],
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is a normal room
- # according to the stripped state (no room type)
- self.assertEqual(filtered_room_map.keys(), {room_id, remote_invite_room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is a normal room
- # according to the stripped state (no room type)
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
-
-class SortRoomsTestCase(HomeserverTestCase):
- """
- Tests Sliding Sync handler `sort_rooms()` to make sure it sorts/orders rooms
- correctly.
- """
-
- servlets = [
- admin.register_servlets,
- knock.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
-
- def _get_sync_room_ids_for_user(
- self,
- user: UserID,
- to_token: StreamToken,
- from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Get the rooms the user should be syncing with
- """
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- user=user,
- from_token=from_token,
- to_token=to_token,
- )
- )
- filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
- user=user,
- room_membership_for_user_map=room_membership_for_user_map,
- )
- )
-
- return filtered_sync_room_map
-
- def test_sort_activity_basic(self) -> None:
- """
- Rooms with newer activity are sorted first.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- )
- room_id2 = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Sort the rooms (what we're testing)
- sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
- sync_room_map=sync_room_map,
- to_token=after_rooms_token,
- )
- )
-
- self.assertEqual(
- [room_membership.room_id for room_membership in sorted_sync_rooms],
- [room_id2, room_id1],
- )
-
- @parameterized.expand(
- [
- (Membership.LEAVE,),
- (Membership.INVITE,),
- (Membership.KNOCK,),
- (Membership.BAN,),
- ]
- )
- def test_activity_after_xxx(self, room1_membership: str) -> None:
- """
- When someone has left/been invited/knocked/been banned from a room, they
- shouldn't take anything into account after that membership event.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create the rooms as user2 so we can have user1 with a clean slate to work from
- # and join in whatever order we need for the tests.
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- # If we're testing knocks, set the room to knock
- if room1_membership == Membership.KNOCK:
- self.helper.send_state(
- room_id1,
- EventTypes.JoinRules,
- {"join_rule": JoinRules.KNOCK},
- tok=user2_tok,
- )
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
- room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
-
- # Here is the activity with user1 that will determine the sort of the rooms
- # (room2, room1, room3)
- self.helper.join(room_id3, user1_id, tok=user1_tok)
- if room1_membership == Membership.LEAVE:
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- elif room1_membership == Membership.INVITE:
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- elif room1_membership == Membership.KNOCK:
- self.helper.knock(room_id1, user1_id, tok=user1_tok)
- elif room1_membership == Membership.BAN:
- self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- # Activity before the token but the user is only been xxx to this room so it
- # shouldn't be taken into account
- self.helper.send(room_id1, "activity in room1", tok=user2_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Activity after the token. Just make it in a different order than what we
- # expect to make sure we're not taking the activity after the token into
- # account.
- self.helper.send(room_id1, "activity in room1", tok=user2_tok)
- self.helper.send(room_id2, "activity in room2", tok=user2_tok)
- self.helper.send(room_id3, "activity in room3", tok=user2_tok)
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Sort the rooms (what we're testing)
- sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
- sync_room_map=sync_room_map,
- to_token=after_rooms_token,
- )
- )
-
- self.assertEqual(
- [room_membership.room_id for room_membership in sorted_sync_rooms],
- [room_id2, room_id1, room_id3],
- "Corresponding map to disambiguate the opaque room IDs: "
- + str(
- {
- "room_id1": room_id1,
- "room_id2": room_id2,
- "room_id3": room_id3,
- }
- ),
- )
-
- def test_default_bump_event_types(self) -> None:
- """
- Test that we only consider the *latest* event in the room when sorting (not
- `bump_event_types`).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- )
- message_response = self.helper.send(room_id1, "message in room1", tok=user1_tok)
- room_id2 = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- )
- self.helper.send(room_id2, "message in room2", tok=user1_tok)
-
- # Send a reaction in room1 which isn't in `DEFAULT_BUMP_EVENT_TYPES` but we only
- # care about sorting by the *latest* event in the room.
- self.helper.send_event(
- room_id1,
- type=EventTypes.Reaction,
- content={
- "m.relates_to": {
- "event_id": message_response["event_id"],
- "key": "👍",
- "rel_type": "m.annotation",
- }
- },
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Sort the rooms (what we're testing)
- sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
- sync_room_map=sync_room_map,
- to_token=after_rooms_token,
- )
- )
-
- self.assertEqual(
- [room_membership.room_id for room_membership in sorted_sync_rooms],
- # room1 sorts before room2 because it has the latest event (the reaction).
- # We only care about the *latest* event in the room.
- [room_id1, room_id2],
- )
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index fa55f76916..37904926e3 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -17,46 +17,25 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import Collection, ContextManager, List, Optional
+from typing import Optional
from unittest.mock import AsyncMock, Mock, patch
-from parameterized import parameterized
-
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
-from synapse.api.filtering import FilterCollection, Filtering
-from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
-from synapse.federation.federation_base import event_from_pdu_json
-from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVersion
+from synapse.api.filtering import Filtering
+from synapse.api.room_versions import RoomVersions
+from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
-from synapse.types import (
- JsonDict,
- MultiWriterStreamToken,
- RoomStreamToken,
- StreamKeyType,
- UserID,
- create_requester,
-)
+from synapse.types import UserID, create_requester
from synapse.util import Clock
import tests.unittest
import tests.utils
-_request_key = 0
-
-
-def generate_request_key() -> SyncRequestKey:
- global _request_key
- _request_key += 1
- return ("request_key", _request_key)
-
class SyncTestCase(tests.unittest.HomeserverTestCase):
"""Tests Sync Handler."""
@@ -89,23 +68,13 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- requester,
- sync_config,
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config)
)
# Test that global lock works
self.auth_blocking._hs_disabled = True
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(
- requester,
- sync_config,
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- ),
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@@ -116,12 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
requester = create_requester(user_id2)
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(
- requester,
- sync_config,
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- ),
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@@ -140,10 +104,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
requester = create_requester(user)
initial_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
- requester,
- sync_config=generate_sync_config(user, device_id="dev"),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
+ requester, sync_config=generate_sync_config(user, device_id="dev")
)
)
@@ -174,10 +135,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# The rooms should appear in the sync response.
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
- requester,
- sync_config=generate_sync_config(user),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
+ requester, sync_config=generate_sync_config(user)
)
)
self.assertIn(joined_room, [r.room_id for r in result.joined])
@@ -189,8 +147,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
requester,
sync_config=generate_sync_config(user, device_id="dev"),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
since_token=initial_result.next_batch,
)
)
@@ -210,8 +166,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
# Blow away caches (supported room versions can only change due to a restart).
+ self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
self.store.get_rooms_for_user.invalidate_all()
- self.store._get_rooms_for_local_user_where_membership_is_inner.invalidate_all()
self.store._get_event_cache.clear()
self.store._event_ref.clear()
@@ -219,10 +175,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Get a new request key.
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
- requester,
- sync_config=generate_sync_config(user),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
+ requester, sync_config=generate_sync_config(user)
)
)
self.assertNotIn(joined_room, [r.room_id for r in result.joined])
@@ -234,8 +187,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
requester,
sync_config=generate_sync_config(user, device_id="dev"),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
since_token=initial_result.next_batch,
)
)
@@ -275,10 +226,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Do a sync as Alice to get the latest event in the room.
alice_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
- create_requester(owner),
- generate_sync_config(owner),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
+ create_requester(owner), generate_sync_config(owner)
)
)
self.assertEqual(len(alice_sync_result.joined), 1)
@@ -298,12 +246,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
eve_requester = create_requester(eve)
eve_sync_config = generate_sync_config(eve)
eve_sync_after_ban: SyncResult = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- eve_requester,
- eve_sync_config,
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
+ self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config)
)
# Sanity check this sync result. We shouldn't be joined to the room.
@@ -312,7 +255,13 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Eve tries to join the room. We monkey patch the internal logic which selects
# the prev_events used when creating the join event, such that the ban does not
# precede the join.
- with self._patch_get_latest_events([last_room_creation_event_id]):
+ mocked_get_prev_events = patch.object(
+ self.hs.get_datastores().main,
+ "get_prev_events_for_room",
+ new_callable=AsyncMock,
+ return_value=[last_room_creation_event_id],
+ )
+ with mocked_get_prev_events:
self.helper.join(room_id, eve, tok=eve_token)
# Eve makes a second, incremental sync.
@@ -320,8 +269,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
eve_requester,
eve_sync_config,
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
since_token=eve_sync_after_ban.next_batch,
)
)
@@ -333,748 +280,25 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
eve_requester,
eve_sync_config,
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
since_token=None,
)
)
self.assertEqual(eve_initial_sync_after_join.joined, [])
- def test_state_includes_changes_on_forks(self) -> None:
- """State changes that happen on a fork of the DAG must be included in `state`
-
- Given the following DAG:
-
- E1
- ↗ ↖
- | S2
- | ↑
- --|------|----
- | |
- E3 |
- ↖ /
- E4
-
- ... and a filter that means we only return 2 events, represented by the dashed
- horizontal line: `S2` must be included in the `state` section.
- """
- alice = self.register_user("alice", "password")
- alice_tok = self.login(alice, "password")
- alice_requester = create_requester(alice)
- room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
-
- # Do an initial sync as Alice to get a known starting point.
- initial_sync_result = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(alice),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
- last_room_creation_event_id = (
- initial_sync_result.joined[0].timeline.events[-1].event_id
- )
-
- # Send a state event, and a regular event, both using the same prev ID
- with self._patch_get_latest_events([last_room_creation_event_id]):
- s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
- "event_id"
- ]
- e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
-
- # Send a final event, joining the two branches of the dag
- e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
-
- # do an incremental sync, with a filter that will ensure we only get two of
- # the three new events.
- incremental_sync = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(
- alice,
- filter_collection=FilterCollection(
- self.hs, {"room": {"timeline": {"limit": 2}}}
- ),
- ),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=initial_sync_result.next_batch,
- )
- )
-
- # The state event should appear in the 'state' section of the response.
- room_sync = incremental_sync.joined[0]
- self.assertEqual(room_sync.room_id, room_id)
- self.assertTrue(room_sync.timeline.limited)
- self.assertEqual(
- [e.event_id for e in room_sync.timeline.events],
- [e3_event, e4_event],
- )
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
-
- def test_state_includes_changes_on_forks_when_events_excluded(self) -> None:
- """A variation on the previous test, but where one event is filtered
-
- The DAG is the same as the previous test, but E4 is excluded by the filter.
-
- E1
- ↗ ↖
- | S2
- | ↑
- --|------|----
- | |
- E3 |
- ↖ /
- (E4)
-
- """
-
- alice = self.register_user("alice", "password")
- alice_tok = self.login(alice, "password")
- alice_requester = create_requester(alice)
- room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
-
- # Do an initial sync as Alice to get a known starting point.
- initial_sync_result = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(alice),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
- last_room_creation_event_id = (
- initial_sync_result.joined[0].timeline.events[-1].event_id
- )
-
- # Send a state event, and a regular event, both using the same prev ID
- with self._patch_get_latest_events([last_room_creation_event_id]):
- s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
- "event_id"
- ]
- e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
-
- # Send a final event, joining the two branches of the dag
- self.helper.send(room_id, "e4", type="not_a_normal_message", tok=alice_tok)[
- "event_id"
- ]
-
- # do an incremental sync, with a filter that will only return E3, excluding S2
- # and E4.
- incremental_sync = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(
- alice,
- filter_collection=FilterCollection(
- self.hs,
- {
- "room": {
- "timeline": {
- "limit": 1,
- "not_types": ["not_a_normal_message"],
- }
- }
- },
- ),
- ),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=initial_sync_result.next_batch,
- )
- )
-
- # The state event should appear in the 'state' section of the response.
- room_sync = incremental_sync.joined[0]
- self.assertEqual(room_sync.room_id, room_id)
- self.assertTrue(room_sync.timeline.limited)
- self.assertEqual(
- [e.event_id for e in room_sync.timeline.events],
- [e3_event],
- )
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
-
- def test_state_includes_changes_on_long_lived_forks(self) -> None:
- """State changes that happen on a fork of the DAG must be included in `state`
-
- Given the following DAG:
-
- E1
- ↗ ↖
- | S2
- | ↑
- --|------|----
- E3 |
- --|------|----
- | E4
- | |
-
- ... and a filter that means we only return 1 event, represented by the dashed
- horizontal lines: `S2` must be included in the `state` section on the second sync.
- """
- alice = self.register_user("alice", "password")
- alice_tok = self.login(alice, "password")
- alice_requester = create_requester(alice)
- room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
-
- # Do an initial sync as Alice to get a known starting point.
- initial_sync_result = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(alice),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
- last_room_creation_event_id = (
- initial_sync_result.joined[0].timeline.events[-1].event_id
- )
-
- # Send a state event, and a regular event, both using the same prev ID
- with self._patch_get_latest_events([last_room_creation_event_id]):
- s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
- "event_id"
- ]
- e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
-
- # Do an incremental sync, this will return E3 but *not* S2 at this
- # point.
- incremental_sync = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(
- alice,
- filter_collection=FilterCollection(
- self.hs, {"room": {"timeline": {"limit": 1}}}
- ),
- ),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=initial_sync_result.next_batch,
- )
- )
- room_sync = incremental_sync.joined[0]
- self.assertEqual(room_sync.room_id, room_id)
- self.assertTrue(room_sync.timeline.limited)
- self.assertEqual(
- [e.event_id for e in room_sync.timeline.events],
- [e3_event],
- )
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [],
- )
-
- # Now send another event that points to S2, but not E3.
- with self._patch_get_latest_events([s2_event]):
- e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
-
- # Doing an incremental sync should return S2 in state.
- incremental_sync = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(
- alice,
- filter_collection=FilterCollection(
- self.hs, {"room": {"timeline": {"limit": 1}}}
- ),
- ),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=incremental_sync.next_batch,
- )
- )
- room_sync = incremental_sync.joined[0]
- self.assertEqual(room_sync.room_id, room_id)
- self.assertFalse(room_sync.timeline.limited)
- self.assertEqual(
- [e.event_id for e in room_sync.timeline.events],
- [e4_event],
- )
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
-
- def test_state_includes_changes_on_ungappy_syncs(self) -> None:
- """Test `state` where the sync is not gappy.
-
- We start with a DAG like this:
-
- E1
- ↗ ↖
- | S2
- |
- --|---
- |
- E3
-
- ... and initialsync with `limit=1`, represented by the horizontal dashed line.
- At this point, we do not expect S2 to appear in the response at all (since
- it is excluded from the timeline by the `limit`, and the state is based on the
- state after the most recent event before the sync token (E3), which doesn't
- include S2.
-
- Now more events arrive, and we do an incremental sync:
-
- E1
- ↗ ↖
- | S2
- | ↑
- E3 |
- ↑ |
- --|------|----
- | |
- E4 |
- ↖ /
- E5
-
- This is the last chance for us to tell the client about S2, so it *must* be
- included in the response.
- """
- alice = self.register_user("alice", "password")
- alice_tok = self.login(alice, "password")
- alice_requester = create_requester(alice)
- room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
-
- # Do an initial sync to get a known starting point.
- initial_sync_result = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(alice),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
- last_room_creation_event_id = (
- initial_sync_result.joined[0].timeline.events[-1].event_id
- )
-
- # Send a state event, and a regular event, both using the same prev ID
- with self._patch_get_latest_events([last_room_creation_event_id]):
- s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
- "event_id"
- ]
- e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
-
- # Another initial sync, with limit=1
- initial_sync_result = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(
- alice,
- filter_collection=FilterCollection(
- self.hs, {"room": {"timeline": {"limit": 1}}}
- ),
- ),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
- room_sync = initial_sync_result.joined[0]
- self.assertEqual(room_sync.room_id, room_id)
- self.assertEqual(
- [e.event_id for e in room_sync.timeline.events],
- [e3_event],
- )
- self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
-
- # More events, E4 and E5
- with self._patch_get_latest_events([e3_event]):
- e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
- e5_event = self.helper.send(room_id, "e5", tok=alice_tok)["event_id"]
-
- # Now incremental sync
- incremental_sync = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- alice_requester,
- generate_sync_config(alice),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=initial_sync_result.next_batch,
- )
- )
-
- # The state event should appear in the 'state' section of the response.
- room_sync = incremental_sync.joined[0]
- self.assertEqual(room_sync.room_id, room_id)
- self.assertFalse(room_sync.timeline.limited)
- self.assertEqual(
- [e.event_id for e in room_sync.timeline.events],
- [e4_event, e5_event],
- )
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
-
- @parameterized.expand(
- [
- (False, False),
- (True, False),
- (False, True),
- (True, True),
- ]
- )
- def test_archived_rooms_do_not_include_state_after_leave(
- self, initial_sync: bool, empty_timeline: bool
- ) -> None:
- """If the user leaves the room, state changes that happen after they leave are not returned.
-
- We try with both a zero and a normal timeline limit,
- and we try both an initial sync and an incremental sync for both.
- """
- if empty_timeline and not initial_sync:
- # FIXME synapse doesn't return the room at all in this situation!
- self.skipTest("Synapse does not correctly handle this case")
-
- # Alice creates the room, and bob joins.
- alice = self.register_user("alice", "password")
- alice_tok = self.login(alice, "password")
-
- bob = self.register_user("bob", "password")
- bob_tok = self.login(bob, "password")
- bob_requester = create_requester(bob)
-
- room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
- self.helper.join(room_id, bob, tok=bob_tok)
-
- initial_sync_result = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- bob_requester,
- generate_sync_config(bob),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
-
- # Alice sends a message and a state
- before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[
- "event_id"
- ]
- before_state_event = self.helper.send_state(
- room_id, "test_state", {"body": "before"}, tok=alice_tok
- )["event_id"]
-
- # Bob leaves
- leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"]
-
- # Alice sends some more stuff
- self.helper.send(room_id, "after", tok=alice_tok)["event_id"]
- self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[
- "event_id"
- ]
-
- # And now, Bob resyncs.
- filter_dict: JsonDict = {"room": {"include_leave": True}}
- if empty_timeline:
- filter_dict["room"]["timeline"] = {"limit": 0}
- sync_room_result = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- bob_requester,
- generate_sync_config(
- bob, filter_collection=FilterCollection(self.hs, filter_dict)
- ),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=None if initial_sync else initial_sync_result.next_batch,
- )
- ).archived[0]
-
- if empty_timeline:
- # The timeline should be empty
- self.assertEqual(sync_room_result.timeline.events, [])
-
- # And the state should include the leave event...
- self.assertEqual(
- sync_room_result.state[("m.room.member", bob)].event_id, leave_event
- )
- # ... and the state change before he left.
- self.assertEqual(
- sync_room_result.state[("test_state", "")].event_id, before_state_event
- )
- else:
- # The last three events in the timeline should be those leading up to the
- # leave
- self.assertEqual(
- [e.event_id for e in sync_room_result.timeline.events[-3:]],
- [before_message_event, before_state_event, leave_event],
- )
- # ... And the state should be empty
- self.assertEqual(sync_room_result.state, {})
-
- def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager:
- """Monkey-patch `get_prev_events_for_room`
-
- Returns a context manager which will replace the implementation of
- `get_prev_events_for_room` with one which returns `latest_events`.
- """
- return patch.object(
- self.hs.get_datastores().main,
- "get_prev_events_for_room",
- new_callable=AsyncMock,
- return_value=latest_events,
- )
-
- def test_call_invite_in_public_room_not_returned(self) -> None:
- user = self.register_user("alice", "password")
- tok = self.login(user, "password")
- room_id = self.helper.create_room_as(user, is_public=True, tok=tok)
- self.handler = self.hs.get_federation_handler()
- federation_event_handler = self.hs.get_federation_event_handler()
-
- async def _check_event_auth(
- origin: Optional[str], event: EventBase, context: EventContext
- ) -> None:
- pass
-
- federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign]
- self.client = self.hs.get_federation_client()
-
- async def _check_sigs_and_hash_for_pulled_events_and_fetch(
- dest: str, pdus: Collection[EventBase], room_version: RoomVersion
- ) -> List[EventBase]:
- return list(pdus)
-
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
-
- prev_events = self.get_success(self.store.get_prev_events_for_room(room_id))
-
- # create a call invite event
- call_event = event_from_pdu_json(
- {
- "type": EventTypes.CallInvite,
- "content": {},
- "room_id": room_id,
- "sender": user,
- "depth": 32,
- "prev_events": prev_events,
- "auth_events": prev_events,
- "origin_server_ts": self.clock.time_msec(),
- },
- RoomVersions.V10,
- )
-
- self.assertEqual(
- self.get_success(
- federation_event_handler.on_receive_pdu("test.serv", call_event)
- ),
- None,
- )
-
- # check that it is in DB
- recent_event = self.get_success(self.store.get_prev_events_for_room(room_id))
- self.assertIn(call_event.event_id, recent_event)
-
- # but that it does not come down /sync in public room
- sync_result: SyncResult = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- create_requester(user),
- generate_sync_config(user),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
- event_ids = []
- for event in sync_result.joined[0].timeline.events:
- event_ids.append(event.event_id)
- self.assertNotIn(call_event.event_id, event_ids)
-
- # it will come down in a private room, though
- user2 = self.register_user("bob", "password")
- tok2 = self.login(user2, "password")
- private_room_id = self.helper.create_room_as(
- user2, is_public=False, tok=tok2, extra_content={"preset": "private_chat"}
- )
-
- priv_prev_events = self.get_success(
- self.store.get_prev_events_for_room(private_room_id)
- )
- private_call_event = event_from_pdu_json(
- {
- "type": EventTypes.CallInvite,
- "content": {},
- "room_id": private_room_id,
- "sender": user,
- "depth": 32,
- "prev_events": priv_prev_events,
- "auth_events": priv_prev_events,
- "origin_server_ts": self.clock.time_msec(),
- },
- RoomVersions.V10,
- )
-
- self.assertEqual(
- self.get_success(
- federation_event_handler.on_receive_pdu("test.serv", private_call_event)
- ),
- None,
- )
-
- recent_events = self.get_success(
- self.store.get_prev_events_for_room(private_room_id)
- )
- self.assertIn(private_call_event.event_id, recent_events)
-
- private_sync_result: SyncResult = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- create_requester(user2),
- generate_sync_config(user2),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
- priv_event_ids = []
- for event in private_sync_result.joined[0].timeline.events:
- priv_event_ids.append(event.event_id)
-
- self.assertIn(private_call_event.event_id, priv_event_ids)
- def test_push_rules_with_bad_account_data(self) -> None:
- """Some old accounts have managed to set a `m.push_rules` account data,
- which we should ignore in /sync response.
- """
-
- user = self.register_user("alice", "password")
-
- # Insert the bad account data.
- self.get_success(
- self.store.add_account_data_for_user(user, AccountDataTypes.PUSH_RULES, {})
- )
-
- sync_result: SyncResult = self.get_success(
- self.sync_handler.wait_for_sync_for_user(
- create_requester(user),
- generate_sync_config(user),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- )
- )
-
- for account_dict in sync_result.account_data:
- if account_dict["type"] == AccountDataTypes.PUSH_RULES:
- # We should have lots of push rules here, rather than the bad
- # empty data.
- self.assertNotEqual(account_dict["content"], {})
- return
-
- self.fail("No push rules found")
-
- def test_wait_for_future_sync_token(self) -> None:
- """Test that if we receive a token that is ahead of our current token,
- we'll wait until the stream position advances.
-
- This can happen if replication streams start lagging, and the client's
- previous sync request was serviced by a worker ahead of ours.
- """
- user = self.register_user("alice", "password")
-
- # We simulate a lagging stream by getting a stream ID from the ID gen
- # and then waiting to mark it as "persisted".
- presence_id_gen = self.store.get_presence_stream_id_gen()
- ctx_mgr = presence_id_gen.get_next()
- stream_id = self.get_success(ctx_mgr.__aenter__())
-
- # Create the new token based on the stream ID above.
- current_token = self.hs.get_event_sources().get_current_token()
- since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id)
-
- sync_d = defer.ensureDeferred(
- self.sync_handler.wait_for_sync_for_user(
- create_requester(user),
- generate_sync_config(user),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=since_token,
- timeout=0,
- )
- )
-
- # This should block waiting for the presence stream to update
- self.pump()
- self.assertFalse(sync_d.called)
-
- # Marking the stream ID as persisted should unblock the request.
- self.get_success(ctx_mgr.__aexit__(None, None, None))
-
- self.get_success(sync_d, by=1.0)
-
- @parameterized.expand(
- [(key,) for key in StreamKeyType.__members__.values()],
- name_func=lambda func, _, param: f"{func.__name__}_{param.args[0].name}",
- )
- def test_wait_for_invalid_future_sync_token(
- self, stream_key: StreamKeyType
- ) -> None:
- """Like the previous test, except we give a token that has a stream
- position ahead of what is in the DB, i.e. its invalid and we shouldn't
- wait for the stream to advance (as it may never do so).
-
- This can happen due to older versions of Synapse giving out stream
- positions without persisting them in the DB, and so on restart the
- stream would get reset back to an older position.
- """
- user = self.register_user("alice", "password")
-
- # Create a token and advance one of the streams.
- current_token = self.hs.get_event_sources().get_current_token()
- token_value = current_token.get_field(stream_key)
-
- # How we advance the streams depends on the type.
- if isinstance(token_value, int):
- since_token = current_token.copy_and_advance(stream_key, token_value + 1)
- elif isinstance(token_value, MultiWriterStreamToken):
- since_token = current_token.copy_and_advance(
- stream_key, MultiWriterStreamToken(stream=token_value.stream + 1)
- )
- elif isinstance(token_value, RoomStreamToken):
- since_token = current_token.copy_and_advance(
- stream_key, RoomStreamToken(stream=token_value.stream + 1)
- )
- else:
- raise Exception("Unreachable")
-
- sync_d = defer.ensureDeferred(
- self.sync_handler.wait_for_sync_for_user(
- create_requester(user),
- generate_sync_config(user),
- sync_version=SyncVersion.SYNC_V2,
- request_key=generate_request_key(),
- since_token=since_token,
- timeout=0,
- )
- )
-
- # We should return without waiting for the presence stream to advance.
- self.get_success(sync_d)
+_request_key = 0
def generate_sync_config(
- user_id: str,
- device_id: Optional[str] = "device_id",
- filter_collection: Optional[FilterCollection] = None,
+ user_id: str, device_id: Optional[str] = "device_id"
) -> SyncConfig:
- """Generate a sync config (with a unique request key).
-
- Args:
- user_id: user who is syncing.
- device_id: device that is syncing. Defaults to "device_id".
- filter_collection: filter to apply. Defaults to the default filter (ie,
- return everything, with a default limit)
- """
- if filter_collection is None:
- filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
-
+ """Generate a sync config (with a unique request key)."""
+ global _request_key
+ _request_key += 1
return SyncConfig(
user=UserID.from_string(user_id),
- filter_collection=filter_collection,
+ filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
is_guest=False,
+ request_key=("request_key", _request_key),
device_id=device_id,
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 9d8960315f..c410fe7cc4 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -32,7 +31,7 @@ from twisted.web.resource import Resource
from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
-from synapse.handlers.typing import FORGET_TIMEOUT, TypingWriterHandler
+from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
@@ -501,54 +500,3 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
}
],
)
-
- def test_prune_typing_replication(self) -> None:
- """Regression test for `get_all_typing_updates` breaking when we prune
- old updates
- """
- self.room_members = [U_APPLE, U_BANANA]
-
- instance_name = self.hs.get_instance_name()
-
- self.get_success(
- self.handler.started_typing(
- target_user=U_APPLE,
- requester=create_requester(U_APPLE),
- room_id=ROOM_ID,
- timeout=10000,
- )
- )
-
- rows, _, _ = self.get_success(
- self.handler.get_all_typing_updates(
- instance_name=instance_name,
- last_id=0,
- current_id=self.handler.get_current_token(),
- limit=100,
- )
- )
- self.assertEqual(rows, [(1, [ROOM_ID, [U_APPLE.to_string()]])])
-
- self.reactor.advance(20000)
-
- rows, _, _ = self.get_success(
- self.handler.get_all_typing_updates(
- instance_name=instance_name,
- last_id=1,
- current_id=self.handler.get_current_token(),
- limit=100,
- )
- )
- self.assertEqual(rows, [(2, [ROOM_ID, []])])
-
- self.reactor.advance(FORGET_TIMEOUT)
-
- rows, _, _ = self.get_success(
- self.handler.get_all_typing_updates(
- instance_name=instance_name,
- last_id=1,
- current_id=self.handler.get_current_token(),
- limit=100,
- )
- )
- self.assertEqual(rows, [])
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 878d9683b6..77c6cac449 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -1061,45 +1061,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
{alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)},
)
- def test_search_punctuation(self) -> None:
- """Test that you can search for a user that includes punctuation"""
-
- searching_user = self.register_user("searcher", "password")
- searching_user_tok = self.login("searcher", "password")
-
- room_id = self.helper.create_room_as(
- searching_user,
- room_version=RoomVersions.V1.identifier,
- tok=searching_user_tok,
- )
-
- # We want to test searching for users of the form e.g. "user-1", with
- # various punctuation. We also test both where the prefix is numeric and
- # alphanumeric, as e.g. postgres tokenises "user-1" as "user" and "-1".
- i = 1
- for char in ["-", ".", "_"]:
- for use_numeric in [False, True]:
- if use_numeric:
- prefix1 = f"{i}"
- prefix2 = f"{i+1}"
- else:
- prefix1 = f"a{i}"
- prefix2 = f"a{i+1}"
-
- local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
- local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
-
- self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
- self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
-
- results = self.get_success(
- self.handler.search_users(searching_user, local_user_1, 20)
- )["results"]
- received_user_id_ordering = [result["user_id"] for result in results]
- self.assertSequenceEqual(received_user_id_ordering[:1], [local_user_1])
-
- i += 2
-
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 6e9a15c8ee..8fec3d6bb7 100644
--- a/tests/handlers/test_worker_lock.py
+++ b/tests/handlers/test_worker_lock.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -27,7 +26,6 @@ from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.utils import test_timeout
class WorkerLockTestCase(unittest.HomeserverTestCase):
@@ -51,28 +49,6 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
self.get_success(d2)
self.get_success(lock2.__aexit__(None, None, None))
- def test_lock_contention(self) -> None:
- """Test lock contention when a lot of locks wait on a single worker"""
-
- # It takes around 0.5s on a 5+ years old laptop
- with test_timeout(5):
- nb_locks = 500
- d = self._take_locks(nb_locks)
- self.assertEqual(self.get_success(d), nb_locks)
-
- async def _take_locks(self, nb_locks: int) -> int:
- locks = [
- self.hs.get_worker_locks_handler().acquire_lock("test_lock", "")
- for _ in range(nb_locks)
- ]
-
- nb_locks_taken = 0
- for lock in locks:
- async with lock:
- nb_locks_taken += 1
-
- return nb_locks_taken
-
class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
def prepare(
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 8e8621e348..4233b1f2cd 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
index dab387a504..3d833a2e44 100644
--- a/tests/http/server/__init__.py
+++ b/tests/http/server/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 731b0c4e59..c1b5f2e3b9 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 721917f957..ff507c7a0f 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -37,155 +36,18 @@ from synapse.http.client import (
BlocklistingAgentWrapper,
BlocklistingReactorWrapper,
BodyExceededMaxSize,
- MultipartResponse,
_DiscardBodyWithMaxSizeProtocol,
- _MultipartParserProtocol,
read_body_with_max_size,
- read_multipart_response,
)
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
-class ReadMultipartResponseTests(TestCase):
- data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
- data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
-
- redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
-
- def _build_multipart_response(
- self, response_length: Union[int, str], max_length: int
- ) -> Tuple[
- BytesIO,
- "Deferred[MultipartResponse]",
- _MultipartParserProtocol,
- ]:
- """Start reading the body, returns the response, result and proto"""
- response = Mock(length=response_length)
- result = BytesIO()
- boundary = "6067d4698f8d40a0a794ea7d7379d53a"
- deferred = read_multipart_response(response, result, boundary, max_length)
-
- # Fish the protocol out of the response.
- protocol = response.deliverBody.call_args[0][0]
- protocol.transport = Mock()
-
- return result, deferred, protocol
-
- def _assert_error(
- self,
- deferred: "Deferred[MultipartResponse]",
- protocol: _MultipartParserProtocol,
- ) -> None:
- """Ensure that the expected error is received."""
- assert isinstance(deferred.result, Failure)
- self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
- assert protocol.transport is not None
- # type-ignore: presumably abortConnection has been replaced with a Mock.
- protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
-
- def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
- """Ensure that the error in the Deferred is handled gracefully."""
- called = [False]
-
- def errback(f: Failure) -> None:
- called[0] = True
-
- deferred.addErrback(errback)
- self.assertTrue(called[0])
-
- def test_parse_file(self) -> None:
- """
- Check that a multipart response containing a file is properly parsed
- into the json/file parts, and the json and file are properly captured
- """
- result, deferred, protocol = self._build_multipart_response(249, 250)
-
- # Start sending data.
- protocol.dataReceived(self.data1)
- protocol.dataReceived(self.data2)
- # Close the connection.
- protocol.connectionLost(Failure(ResponseDone()))
-
- multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
-
- self.assertEqual(multipart_response.json, b"{}")
- self.assertEqual(result.getvalue(), b"file_to_stream")
- self.assertEqual(multipart_response.length, len(b"file_to_stream"))
- self.assertEqual(multipart_response.content_type, b"text/plain")
- self.assertEqual(
- multipart_response.disposition, b"inline; filename=test_upload"
- )
-
- def test_parse_redirect(self) -> None:
- """
- check that a multipart response containing a redirect is properly parsed and redirect url is
- returned
- """
- result, deferred, protocol = self._build_multipart_response(249, 250)
-
- # Start sending data.
- protocol.dataReceived(self.redirect_data)
- # Close the connection.
- protocol.connectionLost(Failure(ResponseDone()))
-
- multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
-
- self.assertEqual(multipart_response.json, b"{}")
- self.assertEqual(result.getvalue(), b"")
- self.assertEqual(
- multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
- )
-
- def test_too_large(self) -> None:
- """A response which is too large raises an exception."""
- result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
-
- # Start sending data.
- protocol.dataReceived(self.data1)
-
- self.assertEqual(result.getvalue(), b"file_")
- self._assert_error(deferred, protocol)
- self._cleanup_error(deferred)
-
- def test_additional_data(self) -> None:
- """A connection can receive data after being closed."""
- result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
-
- # Start sending data.
- protocol.dataReceived(self.data1)
- self._assert_error(deferred, protocol)
-
- # More data might have come in.
- protocol.dataReceived(self.data2)
-
- self.assertEqual(result.getvalue(), b"file_")
- self._assert_error(deferred, protocol)
- self._cleanup_error(deferred)
-
- def test_content_length(self) -> None:
- """The body shouldn't be read (at all) if the Content-Length header is too large."""
- result, deferred, protocol = self._build_multipart_response(250, 1)
-
- # Deferred shouldn't be called yet.
- self.assertFalse(deferred.called)
-
- # Start sending data.
- protocol.dataReceived(self.data1)
- self._assert_error(deferred, protocol)
- self._cleanup_error(deferred)
-
- # The data is never consumed.
- self.assertEqual(result.getvalue(), b"")
-
-
class ReadBodyWithMaxSizeTests(TestCase):
- def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
- BytesIO,
- "Deferred[int]",
- _DiscardBodyWithMaxSizeProtocol,
- ]:
+ def _build_response(
+ self, length: Union[int, str] = UNKNOWN_LENGTH
+ ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=length)
result = BytesIO()
diff --git a/tests/http/test_proxy.py b/tests/http/test_proxy.py
index 5895270494..4b2355536a 100644
--- a/tests/http/test_proxy.py
+++ b/tests/http/test_proxy.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index f71e4c2b8f..8fc66aee26 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 18af2735fe..d16dc419f3 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index b7806fa947..bfcbe0fb10 100644
--- a/tests/http/test_simple_client.py
+++ b/tests/http/test_simple_client.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
index bfa26a329c..3c54569e9a 100644
--- a/tests/http/test_site.py
+++ b/tests/http/test_site.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index f8f2260c2e..13b2754f7b 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index c7ef2bd7a4..38144a9a45 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index f5412ac6e2..80e228b829 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index ff85e067b7..4460b050d9 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/media/__init__.py b/tests/media/__init__.py
index b401d78ef1..3d833a2e44 100644
--- a/tests/media/__init__.py
+++ b/tests/media/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/media/test_filepath.py b/tests/media/test_filepath.py
index cd21b369b4..9d8989ce34 100644
--- a/tests/media/test_filepath.py
+++ b/tests/media/test_filepath.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/media/test_html_preview.py b/tests/media/test_html_preview.py
index d3f1e8833a..0ab33e87b4 100644
--- a/tests/media/test_html_preview.py
+++ b/tests/media/test_html_preview.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py
index 417d17ebd2..7f613f351b 100644
--- a/tests/media/test_media_retention.py
+++ b/tests/media/test_media_retention.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index e55001fb40..a42383bcb6 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -24,7 +23,7 @@ import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import Mock
from urllib import parse
import attr
@@ -36,12 +35,9 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
-from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
from twisted.web.resource import Resource
from synapse.api.errors import Codes, HttpResponseException
-from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
@@ -49,11 +45,11 @@ from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
from synapse.media.storage_provider import FileStorageProviderBackend
-from synapse.media.thumbnailer import ThumbnailProvider
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
from synapse.rest import admin
-from synapse.rest.client import login, media
+from synapse.rest.client import login
+from synapse.rest.media.thumbnail_resource import ThumbnailResource
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
@@ -61,7 +57,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import SMALL_PNG
-from tests.unittest import override_config
from tests.utils import default_config
@@ -128,7 +123,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
@attr.s(auto_attribs=True, slots=True, frozen=True)
-class TestImage:
+class _TestImage:
"""An image for testing thumbnailing with the expected results
Attributes:
@@ -157,54 +152,68 @@ class TestImage:
is_inline: bool = True
-small_png = TestImage(
- SMALL_PNG,
- b"image/png",
- b".png",
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000020000000200806"
- b"000000737a7af40000001a49444154789cedc101010000008220"
- b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
- b"44ae426082"
- ),
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000d49444154789c636060606000000005"
- b"0001a5f645400000000049454e44ae426082"
- ),
-)
-
-small_png_with_transparency = TestImage(
- unhexlify(
- b"89504e470d0a1a0a0000000d49484452000000010000000101000"
- b"00000376ef9240000000274524e5300010194fdae0000000a4944"
- b"4154789c636800000082008177cd72b60000000049454e44ae426"
- b"082"
- ),
- b"image/png",
- b".png",
- # Note that we don't check the output since it varies across
- # different versions of Pillow.
-)
-
-small_lossless_webp = TestImage(
- unhexlify(
- b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
- ),
- b"image/webp",
- b".webp",
-)
-
-empty_file = TestImage(
- b"",
- b"image/gif",
- b".gif",
- expected_found=False,
- unable_to_thumbnail=True,
-)
-
-SVG = TestImage(
- b"""<?xml version="1.0"?>
+@parameterized_class(
+ ("test_image",),
+ [
+ # small png
+ (
+ _TestImage(
+ SMALL_PNG,
+ b"image/png",
+ b".png",
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000020000000200806"
+ b"000000737a7af40000001a49444154789cedc101010000008220"
+ b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
+ b"44ae426082"
+ ),
+ unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000d49444154789c636060606000000005"
+ b"0001a5f645400000000049454e44ae426082"
+ ),
+ ),
+ ),
+ # small png with transparency.
+ (
+ _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d49484452000000010000000101000"
+ b"00000376ef9240000000274524e5300010194fdae0000000a4944"
+ b"4154789c636800000082008177cd72b60000000049454e44ae426"
+ b"082"
+ ),
+ b"image/png",
+ b".png",
+ # Note that we don't check the output since it varies across
+ # different versions of Pillow.
+ ),
+ ),
+ # small lossless webp
+ (
+ _TestImage(
+ unhexlify(
+ b"524946461a000000574542505650384c0d0000002f0000001007"
+ b"1011118888fe0700"
+ ),
+ b"image/webp",
+ b".webp",
+ ),
+ ),
+ # an empty file
+ (
+ _TestImage(
+ b"",
+ b"image/gif",
+ b".gif",
+ expected_found=False,
+ unable_to_thumbnail=True,
+ ),
+ ),
+ # An SVG.
+ (
+ _TestImage(
+ b"""<?xml version="1.0"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
@@ -213,26 +222,17 @@ SVG = TestImage(
<circle cx="100" cy="100" r="50" stroke="black"
stroke-width="5" fill="red" />
</svg>""",
- b"image/svg",
- b".svg",
- expected_found=False,
- unable_to_thumbnail=True,
- is_inline=False,
+ b"image/svg",
+ b".svg",
+ expected_found=False,
+ unable_to_thumbnail=True,
+ is_inline=False,
+ ),
+ ),
+ ],
)
-test_images = [
- small_png,
- small_png_with_transparency,
- small_lossless_webp,
- empty_file,
- SVG,
-]
-input_values = [(x,) for x in test_images]
-
-
-@parameterized_class(("test_image",), input_values)
class MediaRepoTests(unittest.HomeserverTestCase):
- servlets = [media.register_servlets]
- test_image: ClassVar[TestImage]
+ test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
@@ -250,11 +250,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
destination: str,
path: str,
output_stream: BinaryIO,
- download_ratelimiter: Ratelimiter,
- ip_address: Any,
- max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
+ max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
@@ -503,7 +501,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale"
channel = self.make_request(
"GET",
- f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
+ f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -531,7 +529,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
+ f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -572,6 +570,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
await_result=False,
)
self.pump()
+
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
b"Content-Type": [self.test_image.content_type],
@@ -580,6 +579,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
+
if expected_found:
self.assertEqual(channel.code, 200)
@@ -624,12 +624,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
content_type = self.test_image.content_type.decode()
media_repo = self.hs.get_media_repository()
- thumbnail_provider = ThumbnailProvider(
+ thumbnail_resouce = ThumbnailResource(
self.hs, media_repo, media_repo.media_storage
)
self.assertIsNotNone(
- thumbnail_provider._select_thumbnail(
+ thumbnail_resouce._select_thumbnail(
desired_width=desired_size,
desired_height=desired_size,
desired_method=method,
@@ -878,249 +878,3 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
tok=self.tok,
expect_code=400,
)
-
-
-class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
- config["media_store_path"] = self.media_store_path
-
- provider_config = {
- "module": "synapse.media.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
-
- config["media_storage_providers"] = [provider_config]
-
- return self.setup_test_homeserver(config=config)
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.repo = hs.get_media_repository()
- self.client = hs.get_federation_http_client()
- self.store = hs.get_datastores().main
-
- def create_resource_dict(self) -> Dict[str, Resource]:
- # We need to manually set the resource tree to include media, the
- # default only does `/_matrix/client` APIs.
- return {"/_matrix/media": self.hs.get_media_repository_resource()}
-
- # mock actually reading file body
- def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
- d: Deferred = defer.Deferred()
- d.callback(31457280)
- return d
-
- def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
- d: Deferred = defer.Deferred()
- d.callback(52428800)
- return d
-
- @patch(
- "synapse.http.matrixfederationclient.read_body_with_max_size",
- read_body_with_max_size_30MiB,
- )
- def test_download_ratelimit_default(self) -> None:
- """
- Test remote media download ratelimiting against default configuration - 500MB bucket
- and 87kb/second drain rate
- """
-
- # mock out actually sending the request, returns a 30MiB response
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = 31457280
- resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- # first request should go through
- channel = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
- shorthand=False,
- )
- assert channel.code == 200
-
- # next 15 should go through
- for i in range(15):
- channel2 = self.make_request(
- "GET",
- f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
- shorthand=False,
- )
- assert channel2.code == 200
-
- # 17th will hit ratelimit
- channel3 = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
- shorthand=False,
- )
- assert channel3.code == 429
-
- # however, a request from a different IP will go through
- channel4 = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
- shorthand=False,
- client_ip="187.233.230.159",
- )
- assert channel4.code == 200
-
- # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
- # 30MiB download is authorized - The last download was blocked at 503,316,480.
- # The next download will be authorized when bucket hits 492,830,720
- # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
- # needs to drain before another download will be authorized, that will take ~=
- # 2 minutes (10,485,760/89,088/60)
- self.reactor.pump([2.0 * 60.0])
-
- # enough has drained and next request goes through
- channel5 = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
- shorthand=False,
- )
- assert channel5.code == 200
-
- @override_config(
- {
- "remote_media_download_per_second": "50M",
- "remote_media_download_burst_count": "50M",
- }
- )
- @patch(
- "synapse.http.matrixfederationclient.read_body_with_max_size",
- read_body_with_max_size_50MiB,
- )
- def test_download_rate_limit_config(self) -> None:
- """
- Test that download rate limit config options are correctly picked up and applied
- """
-
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = 52428800
- resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- # first request should go through
- channel = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
- shorthand=False,
- )
- assert channel.code == 200
-
- # immediate second request should fail
- channel = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
- shorthand=False,
- )
- assert channel.code == 429
-
- # advance half a second
- self.reactor.pump([0.5])
-
- # request still fails
- channel = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
- shorthand=False,
- )
- assert channel.code == 429
-
- # advance another half second
- self.reactor.pump([0.5])
-
- # enough has drained from bucket and request is successful
- channel = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
- shorthand=False,
- )
- assert channel.code == 200
-
- @override_config({"remote_media_download_burst_count": "87M"})
- @patch(
- "synapse.http.matrixfederationclient.read_body_with_max_size",
- read_body_with_max_size_30MiB,
- )
- def test_download_ratelimit_unknown_length(self) -> None:
- """
- Test that if no content-length is provided, ratelimit will still be applied after
- download once length is known
- """
-
- # mock out actually sending the request
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = UNKNOWN_LENGTH
- resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- # 3 requests should go through (note 3rd one would technically violate ratelimit but
- # is applied *after* download - the next one will be ratelimited)
- for i in range(3):
- channel = self.make_request(
- "GET",
- f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
- shorthand=False,
- )
- assert channel.code == 200
-
- # 4th will hit ratelimit
- channel2 = self.make_request(
- "GET",
- "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
- shorthand=False,
- )
- assert channel2.code == 429
-
- @override_config({"max_upload_size": "29M"})
- @patch(
- "synapse.http.matrixfederationclient.read_body_with_max_size",
- read_body_with_max_size_30MiB,
- )
- def test_max_download_respected(self) -> None:
- """
- Test that the max download size is enforced - note that max download size is determined
- by the max_upload_size
- """
-
- # mock out actually sending the request
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = 31457280
- resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- channel = self.make_request(
- "GET", "/_matrix/media/v3/download/remote.org/abcd", shorthand=False
- )
- assert channel.code == 502
- assert channel.json_body["errcode"] == "M_TOO_LARGE"
diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py
index 29d4580697..abcb33a3b4 100644
--- a/tests/media/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
index 0ae414d408..8d3aa60657 100644
--- a/tests/media/test_url_previewer.py
+++ b/tests/media/test_url_previewer.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 80f24814e8..c989098685 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py
index fd87eaffd0..18cd53673b 100644
--- a/tests/module_api/test_account_data_manager.py
+++ b/tests/module_api/test_account_data_manager.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index b6ba472d7d..ce142e919f 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -688,7 +687,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
channel = self.make_request(
"GET",
- "/notifications",
+ "/notifications?from=",
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/module_api/test_event_unsigned_addition.py b/tests/module_api/test_event_unsigned_addition.py
index c429eff4d6..39ab52174c 100644
--- a/tests/module_api/test_event_unsigned_addition.py
+++ b/tests/module_api/test_event_unsigned_addition.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index fc73f3dc2a..e0d6d027bf 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e0aab1c046..c927a73fa6 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -205,24 +205,8 @@ class EmailPusherTests(HomeserverTestCase):
# Multipart: plain text, base 64 encoded; html, base 64 encoded
multipart_msg = email.message_from_bytes(msg)
-
- # Extract the text (non-HTML) portion of the multipart Message,
- # as a Message.
- txt_message = multipart_msg.get_payload(i=0)
- assert isinstance(txt_message, email.message.Message)
-
- # Extract the actual bytes from the Message object, and decode them to a `str`.
- txt_bytes = txt_message.get_payload(decode=True)
- assert isinstance(txt_bytes, bytes)
- txt = txt_bytes.decode()
-
- # Do the same for the HTML portion of the multipart Message.
- html_message = multipart_msg.get_payload(i=1)
- assert isinstance(html_message, email.message.Message)
- html_bytes = html_message.get_payload(decode=True)
- assert isinstance(html_bytes, bytes)
- html = html_bytes.decode()
-
+ txt = multipart_msg.get_payload()[0].get_payload(decode=True).decode()
+ html = multipart_msg.get_payload()[1].get_payload(decode=True).decode()
self.assertIn("/_synapse/client/unsubscribe", txt)
self.assertIn("/_synapse/client/unsubscribe", html)
@@ -363,17 +347,12 @@ class EmailPusherTests(HomeserverTestCase):
# That email should contain the room's avatar
msg: bytes = args[5]
# Multipart: plain text, base 64 encoded; html, base 64 encoded
-
- # Extract the html Message object from the Multipart Message.
- # We need the asserts to convince mypy that this is OK.
- html_message = email.message_from_bytes(msg).get_payload(i=1)
- assert isinstance(html_message, email.message.Message)
-
- # Extract the `bytes` from the html Message object, and decode to a `str`.
- html = html_message.get_payload(decode=True)
- assert isinstance(html, bytes)
- html = html.decode()
-
+ html = (
+ email.message_from_bytes(msg)
+ .get_payload()[1]
+ .get_payload(decode=True)
+ .decode()
+ )
self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html)
def test_empty_room(self) -> None:
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index bcca472617..dce00d8b7f 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -26,8 +26,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
-from synapse.rest.admin.experimental_features import ExperimentalFeature
-from synapse.rest.client import login, push_rule, pusher, receipts, room, versions
+from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -43,7 +42,6 @@ class HTTPPusherTests(HomeserverTestCase):
receipts.register_servlets,
push_rule.register_servlets,
pusher.register_servlets,
- versions.register_servlets,
]
user_id = True
hijack_auth = False
@@ -971,84 +969,6 @@ class HTTPPusherTests(HomeserverTestCase):
lookup_result.device_id,
)
- def test_device_id_feature_flag(self) -> None:
- """Tests that a pusher created with a given device ID shows that device ID in
- GET /pushers requests when feature is enabled for the user
- """
- user_id = self.register_user("user", "pass")
- access_token = self.login("user", "pass")
-
- # We create the pusher with an HTTP request rather than with
- # _make_user_with_pusher so that we can test the device ID is correctly set when
- # creating a pusher via an API call.
- self.make_request(
- method="POST",
- path="/pushers/set",
- content={
- "kind": "http",
- "app_id": "m.http",
- "app_display_name": "HTTP Push Notifications",
- "device_display_name": "pushy push",
- "pushkey": "a@example.com",
- "lang": "en",
- "data": {"url": "http://example.com/_matrix/push/v1/notify"},
- },
- access_token=access_token,
- )
-
- # Look up the user info for the access token so we can compare the device ID.
- store = self.hs.get_datastores().main
- lookup_result = self.get_success(store.get_user_by_access_token(access_token))
- assert lookup_result is not None
-
- # Check field is not there before we enable the feature flag
- channel = self.make_request("GET", "/pushers", access_token=access_token)
- self.assertEqual(channel.code, 200)
- self.assertEqual(len(channel.json_body["pushers"]), 1)
- self.assertNotIn(
- "org.matrix.msc3881.device_id", channel.json_body["pushers"][0]
- )
-
- self.get_success(
- store.set_features_for_user(user_id, {ExperimentalFeature.MSC3881: True})
- )
-
- # Get the user's devices and check it has the correct device ID.
- channel = self.make_request("GET", "/pushers", access_token=access_token)
- self.assertEqual(channel.code, 200)
- self.assertEqual(len(channel.json_body["pushers"]), 1)
- self.assertEqual(
- channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
- lookup_result.device_id,
- )
-
- def test_msc3881_client_versions_flag(self) -> None:
- """Tests that MSC3881 only appears in /versions if user has it enabled."""
-
- user_id = self.register_user("user", "pass")
- access_token = self.login("user", "pass")
-
- # Check feature is disabled in /versions
- channel = self.make_request(
- "GET", "/_matrix/client/versions", access_token=access_token
- )
- self.assertEqual(channel.code, 200)
- self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc3881"])
-
- # Enable feature for user
- self.get_success(
- self.hs.get_datastores().main.set_features_for_user(
- user_id, {ExperimentalFeature.MSC3881: True}
- )
- )
-
- # Check feature is now enabled in /versions for user
- channel = self.make_request(
- "GET", "/_matrix/client/versions", access_token=access_token
- )
- self.assertEqual(channel.code, 200)
- self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc3881"])
-
@override_config({"push": {"jitter_delay": "10s"}})
def test_jitter(self) -> None:
"""Tests that enabling jitter actually delays sending push."""
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
index bd42fc0580..1f41c23e73 100644
--- a/tests/push/test_presentable_names.py
+++ b/tests/push/test_presentable_names.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 420fbea998..b129332bb7 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/__init__.py b/tests/replication/__init__.py
index 587ee42067..3d833a2e44 100644
--- a/tests/replication/__init__.py
+++ b/tests/replication/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 8437da1cdd..d2220f8195 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -495,9 +495,9 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self) -> None:
- self._subscribers_by_channel: Dict[bytes, Set["FakeRedisPubSubProtocol"]] = (
- defaultdict(set)
- )
+ self._subscribers_by_channel: Dict[
+ bytes, Set["FakeRedisPubSubProtocol"]
+ ] = defaultdict(set)
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py
index dab387a504..3d833a2e44 100644
--- a/tests/replication/http/__init__.py
+++ b/tests/replication/http/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index 2eaad3707a..f8a07b34e8 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/storage/__init__.py b/tests/replication/storage/__init__.py
index 587ee42067..3d833a2e44 100644
--- a/tests/replication/storage/__init__.py
+++ b/tests/replication/storage/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py
index 27dff0034f..357ff79ed3 100644
--- a/tests/replication/storage/_base.py
+++ b/tests/replication/storage/_base.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py
index 1afe523d02..69350963f2 100644
--- a/tests/replication/storage/test_events.py
+++ b/tests/replication/storage/test_events.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -30,16 +29,19 @@ from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
+from synapse.handlers.room import RoomEventSource
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import (
NotifCounts,
RoomNotifCounts,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
from synapse.util import Clock
+from tests.server import FakeTransport
+
from ._base import BaseWorkerStoreTestCase
USER_ID = "@feeling:test"
@@ -138,7 +140,6 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
- assert event.internal_metadata.instance_name is not None
assert event.internal_metadata.stream_ordering is not None
self.replicate()
@@ -152,10 +153,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
USER_ID,
"invite",
event.event_id,
- PersistedEventPosition(
- event.internal_metadata.instance_name,
- event.internal_metadata.stream_ordering,
- ),
+ event.internal_metadata.stream_ordering,
RoomVersions.V1.identifier,
)
],
@@ -218,6 +216,122 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
),
)
+ def test_get_rooms_for_user_with_stream_ordering(self) -> None:
+ """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
+ by rows in the events stream
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ j2 = self.persist(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ assert j2.internal_metadata.stream_ordering is not None
+ self.replicate()
+
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
+ self.check(
+ "get_rooms_for_user_with_stream_ordering",
+ (USER_ID_2,),
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+ )
+
+ def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+ self,
+ ) -> None:
+ """Check that current_state invalidation happens correctly with multiple events
+ in the persistence batch.
+
+ This test attempts to reproduce a race condition between the event persistence
+ loop and a worker-based Sync handler.
+
+ The problem occurred when the master persisted several events in one batch. It
+ only updates the current_state at the end of each batch, so the obvious thing
+ to do is then to issue a current_state_delta stream update corresponding to the
+ last stream_id in the batch.
+
+ However, that raises the possibility that a worker will see the replication
+ notification for a join event before the current_state caches are invalidated.
+
+ The test involves:
+ * creating a join and a message event for a user, and persisting them in the
+ same batch
+
+ * controlling the replication stream so that updates are sent gradually
+
+ * between each bunch of replication updates, check that we see a consistent
+ snapshot of the state.
+ """
+ self.persist(type="m.room.create", key="", creator=USER_ID)
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
+ self.replicate()
+ self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
+
+ # limit the replication rate
+ repl_transport = self._server_transport
+ assert isinstance(repl_transport, FakeTransport)
+ repl_transport.autoflush = False
+
+ # build the join and message events and persist them in the same batch.
+ logger.info("----- build test events ------")
+ j2, j2ctx = self.build_event(
+ type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
+ )
+ msg, msgctx = self.build_event()
+ self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
+ self.replicate()
+ assert j2.internal_metadata.stream_ordering is not None
+
+ event_source = RoomEventSource(self.hs)
+ event_source.store = self.worker_store
+ current_token = event_source.get_current_key()
+
+ # gradually stream out the replication
+ while repl_transport.buffer:
+ logger.info("------ flush ------")
+ repl_transport.flush(30)
+ self.pump(0)
+
+ prev_token = current_token
+ current_token = event_source.get_current_key()
+
+ # attempt to replicate the behaviour of the sync handler.
+ #
+ # First, we get a list of the rooms we are joined to
+ joined_rooms = self.get_success(
+ self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
+ )
+
+ # Then, we get a list of the events since the last sync
+ membership_changes = self.get_success(
+ self.worker_store.get_membership_changes_for_user(
+ USER_ID_2, prev_token, current_token
+ )
+ )
+
+ logger.info(
+ "%s->%s: joined_rooms=%r membership_changes=%r",
+ prev_token,
+ current_token,
+ joined_rooms,
+ membership_changes,
+ )
+
+ # the membership change is only any use to us if the room is in the
+ # joined_rooms list.
+ if membership_changes:
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
+ self.assertEqual(
+ joined_rooms,
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+ )
+
event_id = 0
def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 6dea29ae15..9d934a8dfe 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index abd4e0b8cd..e3d4c01aae 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py
index cb07e93d6b..e9e9590396 100644
--- a/tests/replication/tcp/streams/test_to_device.py
+++ b/tests/replication/tcp/streams/test_to_device.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index b1c2f5b03b..198045dd86 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index 87a2365ae3..10efc00679 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index a8eb7fc523..ecdf8e6679 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 7820de8acc..517e30d160 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 298c662555..fb8e8ec3e5 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 14c9483f2b..2219ade77e 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 4429d0f4e2..6ebd54d8f4 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
index 1e7183edaa..bbf3ad371d 100644
--- a/tests/replication/test_module_cache_invalidation.py
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 6fc4600c41..70c10d126d 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -28,7 +27,7 @@ from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
-from synapse.rest.client import login, media
+from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -255,238 +254,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
-class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
- """Checks running multiple media repos work correctly using autheticated media paths"""
-
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- media.register_servlets,
- ]
-
- file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user_id = self.register_user("user", "pass")
- self.access_token = self.login("user", "pass")
-
- self.reactor.lookups["example.com"] = "1.2.3.4"
-
- def default_config(self) -> dict:
- conf = super().default_config()
- conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
- return conf
-
- def make_worker_hs(
- self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
- ) -> HomeServer:
- worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
- # Force the media paths onto the replication resource.
- worker_hs.get_media_repository_resource().register_servlets(
- self._hs_to_site[worker_hs].resource, worker_hs
- )
- return worker_hs
-
- def _get_media_req(
- self, hs: HomeServer, target: str, media_id: str
- ) -> Tuple[FakeChannel, Request]:
- """Request some remote media from the given HS by calling the download
- API.
-
- This then triggers an outbound request from the HS to the target.
-
- Returns:
- The channel for the *client* request and the *outbound* request for
- the media which the caller should respond to.
- """
- channel = make_request(
- self.reactor,
- self._hs_to_site[hs],
- "GET",
- f"/_matrix/client/v1/media/download/{target}/{media_id}",
- shorthand=False,
- access_token=self.access_token,
- await_result=False,
- )
- self.pump()
-
- clients = self.reactor.tcpClients
- self.assertGreaterEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
-
- # build the test server
- server_factory = Factory.forProtocol(HTTPChannel)
- # Request.finish expects the factory to have a 'log' method.
- server_factory.log = _log_request
-
- server_tls_protocol = wrap_server_factory_for_tls(
- server_factory, self.reactor, sanlist=[b"DNS:example.com"]
- ).buildProtocol(None)
-
- # now, tell the client protocol factory to build the client protocol (it will be a
- # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
- # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
- # a FakeTransport.
- #
- # Normally this would be done by the TCP socket code in Twisted, but we are
- # stubbing that out here.
- client_protocol = client_factory.buildProtocol(None)
- client_protocol.makeConnection(
- FakeTransport(server_tls_protocol, self.reactor, client_protocol)
- )
-
- # tell the server tls protocol to send its stuff back to the client, too
- server_tls_protocol.makeConnection(
- FakeTransport(client_protocol, self.reactor, server_tls_protocol)
- )
-
- # fish the test server back out of the server-side TLS protocol.
- http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
-
- # give the reactor a pump to get the TLS juices flowing.
- self.reactor.pump((0.1,))
-
- self.assertEqual(len(http_server.requests), 1)
- request = http_server.requests[0]
-
- self.assertEqual(request.method, b"GET")
- self.assertEqual(
- request.path,
- f"/_matrix/federation/v1/media/download/{media_id}".encode(),
- )
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
- )
-
- return channel, request
-
- def test_basic(self) -> None:
- """Test basic fetching of remote media from a single worker."""
- hs1 = self.make_worker_hs("synapse.app.generic_worker")
-
- channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
-
- request.setResponseCode(200)
- request.responseHeaders.setRawHeaders(
- b"Content-Type",
- ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
- )
- request.write(self.file_data)
- request.finish()
-
- self.pump(0.1)
-
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.result["body"], b"file_to_stream")
-
- def test_download_simple_file_race(self) -> None:
- """Test that fetching remote media from two different processes at the
- same time works.
- """
- hs1 = self.make_worker_hs("synapse.app.generic_worker")
- hs2 = self.make_worker_hs("synapse.app.generic_worker")
-
- start_count = self._count_remote_media()
-
- # Make two requests without responding to the outbound media requests.
- channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
- channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
-
- # Respond to the first outbound media request and check that the client
- # request is successful
- request1.setResponseCode(200)
- request1.responseHeaders.setRawHeaders(
- b"Content-Type",
- ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
- )
- request1.write(self.file_data)
- request1.finish()
-
- self.pump(0.1)
-
- self.assertEqual(channel1.code, 200, channel1.result["body"])
- self.assertEqual(channel1.result["body"], b"file_to_stream")
-
- # Now respond to the second with the same content.
- request2.setResponseCode(200)
- request2.responseHeaders.setRawHeaders(
- b"Content-Type",
- ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
- )
- request2.write(self.file_data)
- request2.finish()
-
- self.pump(0.1)
-
- self.assertEqual(channel2.code, 200, channel2.result["body"])
- self.assertEqual(channel2.result["body"], b"file_to_stream")
-
- # We expect only one new file to have been persisted.
- self.assertEqual(start_count + 1, self._count_remote_media())
-
- def test_download_image_race(self) -> None:
- """Test that fetching remote *images* from two different processes at
- the same time works.
-
- This checks that races generating thumbnails are handled correctly.
- """
- hs1 = self.make_worker_hs("synapse.app.generic_worker")
- hs2 = self.make_worker_hs("synapse.app.generic_worker")
-
- start_count = self._count_remote_thumbnails()
-
- channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
- channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
-
- request1.setResponseCode(200)
- request1.responseHeaders.setRawHeaders(
- b"Content-Type",
- ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
- )
- img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n"
- request1.write(img_data)
- request1.write(SMALL_PNG)
- request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
- request1.finish()
-
- self.pump(0.1)
-
- self.assertEqual(channel1.code, 200, channel1.result["body"])
- self.assertEqual(channel1.result["body"], SMALL_PNG)
-
- request2.setResponseCode(200)
- request2.responseHeaders.setRawHeaders(
- b"Content-Type",
- ["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
- )
- request2.write(img_data)
- request2.write(SMALL_PNG)
- request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
- request2.finish()
-
- self.pump(0.1)
-
- self.assertEqual(channel2.code, 200, channel2.result["body"])
- self.assertEqual(channel2.result["body"], SMALL_PNG)
-
- # We expect only three new thumbnails to have been persisted.
- self.assertEqual(start_count + 3, self._count_remote_thumbnails())
-
- def _count_remote_media(self) -> int:
- """Count the number of files in our remote media directory."""
- path = os.path.join(
- self.hs.get_media_repository().primary_base_path, "remote_content"
- )
- return sum(len(files) for _, _, files in os.walk(path))
-
- def _count_remote_thumbnails(self) -> int:
- """Count the number of files in our remote thumbnails directory."""
- path = os.path.join(
- self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
- )
- return sum(len(files) for _, _, files in os.walk(path))
-
-
def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 1b0bdc262a..7002521893 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index ce6ad75901..04e58cf5df 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/replication/test_sharded_receipts.py b/tests/replication/test_sharded_receipts.py
index e400267819..81b319c858 100644
--- a/tests/replication/test_sharded_receipts.py
+++ b/tests/replication/test_sharded_receipts.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/__init__.py b/tests/rest/__init__.py
index 6a72062b0c..3d833a2e44 100644
--- a/tests/rest/__init__.py
+++ b/tests/rest/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6351326fff..defccd7d12 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -384,7 +383,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
"PUT",
url,
content={
- "features": {"msc3881": True},
+ "features": {"msc3026": True, "msc3881": True},
},
access_token=self.admin_user_tok,
)
@@ -401,6 +400,10 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(
True,
+ channel.json_body["features"]["msc3026"],
+ )
+ self.assertEqual(
+ True,
channel.json_body["features"]["msc3881"],
)
@@ -409,7 +412,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
url,
- content={"features": {"msc3881": False}},
+ content={"features": {"msc3026": False}},
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)
@@ -425,15 +428,23 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(
False,
+ channel.json_body["features"]["msc3026"],
+ )
+ self.assertEqual(
+ True,
channel.json_body["features"]["msc3881"],
)
+ self.assertEqual(
+ False,
+ channel.json_body["features"]["msc3967"],
+ )
# test nothing blows up if you try to disable a feature that isn't already enabled
url = f"{self.url}/{self.other_user}"
channel = self.make_request(
"PUT",
url,
- content={"features": {"msc3881": False}},
+ content={"features": {"msc3026": False}},
access_token=self.admin_user_tok,
)
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index f33aada64b..caea9d4415 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index a88c77bd19..c8f6fa105a 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Dirk Klimpel
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index feb410a11d..3e695e9700 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Dirk Klimpel
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -24,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
-from synapse.rest.client import login, reporting, room
+from synapse.rest.client import login, report_event, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -37,7 +36,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
- reporting.register_servlets,
+ report_event.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -453,7 +452,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
- reporting.register_servlets,
+ report_event.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index c2015774a1..1cdb1105eb 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -778,81 +777,20 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
self._check_fields(channel.json_body["rooms"])
- def test_room_filtering(self) -> None:
- """Tests that rooms are correctly filtered"""
-
- # Create two rooms on the homeserver. Each has a different remote homeserver
- # participating in it.
- other_destination = "other.destination.org"
- room_ids_self_dest = self._create_destination_rooms(2, destination=self.dest)
- room_ids_other_dest = self._create_destination_rooms(
- 1, destination=other_destination
- )
-
- # Ask for the rooms that `self.dest` is participating in.
- channel = self.make_request("GET", self.url, access_token=self.admin_user_tok)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Verify that we received only the rooms that `self.dest` is participating in.
- # This assertion method name is a bit misleading. It does check that both lists
- # contain the same items, and the same counts.
- self.assertCountEqual(
- [r["room_id"] for r in channel.json_body["rooms"]], room_ids_self_dest
- )
- self.assertEqual(channel.json_body["total"], len(room_ids_self_dest))
-
- # Ask for the rooms that `other_destination` is participating in.
- channel = self.make_request(
- "GET",
- self.url.replace(self.dest, other_destination),
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Verify that we received only the rooms that `other_destination` is
- # participating in.
- self.assertCountEqual(
- [r["room_id"] for r in channel.json_body["rooms"]], room_ids_other_dest
- )
- self.assertEqual(channel.json_body["total"], len(room_ids_other_dest))
-
- def _create_destination_rooms(
- self,
- number_rooms: int,
- destination: Optional[str] = None,
- ) -> List[str]:
- """
- Create the given number of rooms. The given `destination` homeserver will
- be recorded as a participant.
+ def _create_destination_rooms(self, number_rooms: int) -> None:
+ """Create a number rooms for destination
Args:
number_rooms: Number of rooms to be created
- destination: The domain of the homeserver that will be considered
- as a participant in the rooms.
-
- Returns:
- The IDs of the rooms that have been created.
"""
- room_ids = []
-
- # If no destination was provided, default to `self.dest`.
- if destination is None:
- destination = self.dest
-
for _ in range(number_rooms):
room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
- room_ids.append(room_id)
-
self.get_success(
- self.store.store_destination_rooms_entries(
- (destination,), room_id, 1234
- )
+ self.store.store_destination_rooms_entries((self.dest,), room_id, 1234)
)
- return room_ids
-
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected room attributes are present in content
diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py
index 55b822c4d0..3636ea3415 100644
--- a/tests/rest/admin/test_jwks.py
+++ b/tests/rest/admin/test_jwks.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index f378165513..ea5b4e2c12 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-# Copyright 2020 Dirk Klimpel
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -277,8 +275,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual(
- "Missing required integer query parameter before_ts",
- channel.json_body["error"],
+ "Missing integer query parameter 'before_ts'", channel.json_body["error"]
)
def test_invalid_parameter(self) -> None:
@@ -321,7 +318,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
- "Query parameter size_gt must be a positive integer.",
+ "Query parameter size_gt must be a string representing a positive integer.",
channel.json_body["error"],
)
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 67d1db8ff8..773cfa8d6a 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 Callum Brown
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 95ed736451..a511175b99 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Dirk Klimpel
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -21,7 +20,6 @@
import json
import time
import urllib.parse
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import AsyncMock, Mock
@@ -1795,83 +1793,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
- def test_filter_public_rooms(self) -> None:
- self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=True
- )
- self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=True
- )
- self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=False
- )
-
- response = self.make_request(
- "GET",
- "/_synapse/admin/v1/rooms",
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, response.code, msg=response.json_body)
- self.assertEqual(3, response.json_body["total_rooms"])
- self.assertEqual(3, len(response.json_body["rooms"]))
-
- response = self.make_request(
- "GET",
- "/_synapse/admin/v1/rooms?public_rooms=true",
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, response.code, msg=response.json_body)
- self.assertEqual(2, response.json_body["total_rooms"])
- self.assertEqual(2, len(response.json_body["rooms"]))
-
- response = self.make_request(
- "GET",
- "/_synapse/admin/v1/rooms?public_rooms=false",
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, response.code, msg=response.json_body)
- self.assertEqual(1, response.json_body["total_rooms"])
- self.assertEqual(1, len(response.json_body["rooms"]))
-
- def test_filter_empty_rooms(self) -> None:
- self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=True
- )
- self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=True
- )
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=False
- )
- self.helper.leave(room_id, self.admin_user, tok=self.admin_user_tok)
-
- response = self.make_request(
- "GET",
- "/_synapse/admin/v1/rooms",
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, response.code, msg=response.json_body)
- self.assertEqual(3, response.json_body["total_rooms"])
- self.assertEqual(3, len(response.json_body["rooms"]))
-
- response = self.make_request(
- "GET",
- "/_synapse/admin/v1/rooms?empty_rooms=false",
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, response.code, msg=response.json_body)
- self.assertEqual(2, response.json_body["total_rooms"])
- self.assertEqual(2, len(response.json_body["rooms"]))
-
- response = self.make_request(
- "GET",
- "/_synapse/admin/v1/rooms?empty_rooms=true",
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, response.code, msg=response.json_body)
- self.assertEqual(1, response.json_body["total_rooms"])
- self.assertEqual(1, len(response.json_body["rooms"]))
-
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
# Create two test rooms
@@ -2268,33 +2189,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
- def test_room_message_filter_query_validation(self) -> None:
- # Test json validation in (filter) query parameter.
- # Does not test the validity of the filter, only the json validation.
-
- # Check Get with valid json filter parameter, expect 200.
- valid_filter_str = '{"types": ["m.room.message"]}'
- channel = self.make_request(
- "GET",
- f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
- invalid_filter_str = "}}}{}"
- channel = self.make_request(
- "GET",
- f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
- self.assertEqual(
- channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
- )
-
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -2627,39 +2521,6 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
else:
self.fail("Event %s from events_after not found" % j)
- def test_room_event_context_filter_query_validation(self) -> None:
- # Test json validation in (filter) query parameter.
- # Does not test the validity of the filter, only the json validation.
-
- # Create a user with room and event_id.
- user_id = self.register_user("test", "test")
- user_tok = self.login("test", "test")
- room_id = self.helper.create_room_as(user_id, tok=user_tok)
- event_id = self.helper.send(room_id, "message 1", tok=user_tok)["event_id"]
-
- # Check Get with valid json filter parameter, expect 200.
- valid_filter_str = '{"types": ["m.room.message"]}'
- channel = self.make_request(
- "GET",
- f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
- invalid_filter_str = "}}}{}"
- channel = self.make_request(
- "GET",
- f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
- self.assertEqual(
- channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
- )
-
class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 2a1e42bbc8..ce5e3a5c1f 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 Dirk Klimpel
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 5f60e19e56..0b30e8c65f 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-# Copyright 2020 Dirk Klimpel
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 16bb4349f5..61cbac2332 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -37,7 +36,6 @@ from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.media.filepath import MediaFilePaths
-from synapse.rest import admin
from synapse.rest.client import (
devices,
login,
@@ -504,7 +502,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"{self.url}?deactivated=true",
+ self.url + "?deactivated=true",
{},
access_token=self.admin_user_tok,
)
@@ -983,56 +981,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(1, channel.json_body["total"])
self.assertFalse(channel.json_body["users"][0]["admin"])
- def test_filter_deactivated_users(self) -> None:
- """
- Tests whether the various values of the query parameter `deactivated` lead to the
- expected result set.
- """
- users_url_v3 = self.url.replace("v2", "v3")
-
- # Register an additional non admin user
- user_id = self.register_user("user", "pass", admin=False)
-
- # Deactivate that user, requesting erasure.
- deactivate_account_handler = self.hs.get_deactivate_account_handler()
- self.get_success(
- deactivate_account_handler.deactivate_account(
- user_id, erase_data=True, requester=create_requester(user_id)
- )
- )
-
- # Query all users
- channel = self.make_request(
- "GET",
- users_url_v3,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, channel.result)
- self.assertEqual(2, channel.json_body["total"])
-
- # Query deactivated users
- channel = self.make_request(
- "GET",
- f"{users_url_v3}?deactivated=true",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, channel.result)
- self.assertEqual(1, channel.json_body["total"])
- self.assertEqual("@user:test", channel.json_body["users"][0]["name"])
-
- # Query non-deactivated users
- channel = self.make_request(
- "GET",
- f"{users_url_v3}?deactivated=false",
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, channel.result)
- self.assertEqual(1, channel.json_body["total"])
- self.assertEqual("@admin:test", channel.json_body["users"][0]["name"])
-
@override_config(
{
"experimental_features": {
@@ -1181,7 +1129,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# They should appear in the list users API, marked as not erased.
channel = self.make_request(
"GET",
- f"{self.url}?deactivated=true",
+ self.url + "?deactivated=true",
access_token=self.admin_user_tok,
)
users = {user["name"]: user for user in channel.json_body["users"]}
@@ -1245,7 +1193,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
dir: The direction of ordering to give the server
"""
- url = f"{self.url}?deactivated=true&"
+ url = self.url + "?deactivated=true&"
if order_by is not None:
url += "order_by=%s&" % (order_by,)
if dir is not None and dir in ("b", "f"):
@@ -5006,86 +4954,3 @@ class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase):
)
assert timestamp is not None
self.assertGreater(timestamp, self.clock.time_msec())
-
-
-class UserSuspensionTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- admin.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.admin = self.register_user("thomas", "hackme", True)
- self.admin_tok = self.login("thomas", "hackme")
-
- self.bad_user = self.register_user("teresa", "hackme")
- self.bad_user_tok = self.login("teresa", "hackme")
-
- self.store = hs.get_datastores().main
-
- @override_config({"experimental_features": {"msc3823_account_suspension": True}})
- def test_suspend_user(self) -> None:
- # test that suspending user works
- channel = self.make_request(
- "PUT",
- f"/_synapse/admin/v1/suspend/{self.bad_user}",
- {"suspend": True},
- access_token=self.admin_tok,
- )
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body, {f"user_{self.bad_user}_suspended": True})
-
- res = self.get_success(self.store.get_user_suspended_status(self.bad_user))
- self.assertEqual(True, res)
-
- # test that un-suspending user works
- channel2 = self.make_request(
- "PUT",
- f"/_synapse/admin/v1/suspend/{self.bad_user}",
- {"suspend": False},
- access_token=self.admin_tok,
- )
- self.assertEqual(channel2.code, 200)
- self.assertEqual(channel2.json_body, {f"user_{self.bad_user}_suspended": False})
-
- res2 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
- self.assertEqual(False, res2)
-
- # test that trying to un-suspend user who isn't suspended doesn't cause problems
- channel3 = self.make_request(
- "PUT",
- f"/_synapse/admin/v1/suspend/{self.bad_user}",
- {"suspend": False},
- access_token=self.admin_tok,
- )
- self.assertEqual(channel3.code, 200)
- self.assertEqual(channel3.json_body, {f"user_{self.bad_user}_suspended": False})
-
- res3 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
- self.assertEqual(False, res3)
-
- # test that trying to suspend user who is already suspended doesn't cause problems
- channel4 = self.make_request(
- "PUT",
- f"/_synapse/admin/v1/suspend/{self.bad_user}",
- {"suspend": True},
- access_token=self.admin_tok,
- )
- self.assertEqual(channel4.code, 200)
- self.assertEqual(channel4.json_body, {f"user_{self.bad_user}_suspended": True})
-
- res4 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
- self.assertEqual(True, res4)
-
- channel5 = self.make_request(
- "PUT",
- f"/_synapse/admin/v1/suspend/{self.bad_user}",
- {"suspend": True},
- access_token=self.admin_tok,
- )
- self.assertEqual(channel5.code, 200)
- self.assertEqual(channel5.json_body, {f"user_{self.bad_user}_suspended": True})
-
- res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
- self.assertEqual(True, res5)
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 4dd5de33d3..d302be33a3 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/__init__.py b/tests/rest/client/__init__.py
index 6a72062b0c..3d833a2e44 100644
--- a/tests/rest/client/__init__.py
+++ b/tests/rest/client/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/sliding_sync/__init__.py b/tests/rest/client/sliding_sync/__init__.py
deleted file mode 100644
index c4de9d53e2..0000000000
--- a/tests/rest/client/sliding_sync/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
diff --git a/tests/rest/client/sliding_sync/test_connection_tracking.py b/tests/rest/client/sliding_sync/test_connection_tracking.py
deleted file mode 100644
index 4d8866b30a..0000000000
--- a/tests/rest/client/sliding_sync/test_connection_tracking.py
+++ /dev/null
@@ -1,453 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from parameterized import parameterized
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import EventTypes
-from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
-from synapse.types import SlidingSyncStreamToken
-from synapse.types.handlers import SlidingSyncConfig
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncConnectionTrackingTestCase(SlidingSyncBase):
- """
- Test connection tracking in the Sliding Sync API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
-
- def test_rooms_required_state_incremental_sync_LIVE(self) -> None:
- """Test that we only get state updates in incremental sync for rooms
- we've already seen (LIVE).
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.RoomHistoryVisibility, ""],
- # This one doesn't exist in the room
- [EventTypes.Name, ""],
- ],
- "timeline_limit": 0,
- }
- }
- }
-
- response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- state_map[(EventTypes.RoomHistoryVisibility, "")],
- },
- exact=True,
- )
-
- # Send a state event
- self.helper.send_state(
- room_id1, EventTypes.Name, body={"name": "foo"}, tok=user2_tok
- )
-
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- self.assertNotIn("initial", response_body["rooms"][room_id1])
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Name, "")],
- },
- exact=True,
- )
-
- @parameterized.expand([(False,), (True,)])
- def test_rooms_timeline_incremental_sync_PREVIOUSLY(self, limited: bool) -> None:
- """
- Test getting room data where we have previously sent down the room, but
- we missed sending down some timeline events previously and so its status
- is considered PREVIOUSLY.
-
- There are two versions of this test, one where there are more messages
- than the timeline limit, and one where there isn't.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- self.helper.send(room_id1, "msg", tok=user1_tok)
-
- timeline_limit = 5
- conn_id = "conn_id"
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 0]],
- "required_state": [],
- "timeline_limit": timeline_limit,
- }
- },
- "conn_id": "conn_id",
- }
-
- # The first room gets sent down the initial sync
- response_body, initial_from_token = self.do_sync(sync_body, tok=user1_tok)
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id1}, response_body["rooms"]
- )
-
- # We now send down some events in room1 (depending on the test param).
- expected_events = [] # The set of events in the timeline
- if limited:
- for _ in range(10):
- resp = self.helper.send(room_id1, "msg1", tok=user1_tok)
- expected_events.append(resp["event_id"])
- else:
- resp = self.helper.send(room_id1, "msg1", tok=user1_tok)
- expected_events.append(resp["event_id"])
-
- # A second messages happens in the other room, so room1 won't get sent down.
- self.helper.send(room_id2, "msg", tok=user1_tok)
-
- # Only the second room gets sent down sync.
- response_body, from_token = self.do_sync(
- sync_body, since=initial_from_token, tok=user1_tok
- )
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id2}, response_body["rooms"]
- )
-
- # FIXME: This is a hack to record that the first room wasn't sent down
- # sync, as we don't implement that currently.
- sliding_sync_handler = self.hs.get_sliding_sync_handler()
- requester = self.get_success(
- self.hs.get_auth().get_user_by_access_token(user1_tok)
- )
- sync_config = SlidingSyncConfig(
- user=requester.user,
- requester=requester,
- conn_id=conn_id,
- )
-
- parsed_initial_from_token = self.get_success(
- SlidingSyncStreamToken.from_string(self.store, initial_from_token)
- )
- connection_position = self.get_success(
- sliding_sync_handler.connection_store.record_rooms(
- sync_config,
- parsed_initial_from_token,
- sent_room_ids=[],
- unsent_room_ids=[room_id1],
- )
- )
-
- # FIXME: Now fix up `from_token` with new connect position above.
- parsed_from_token = self.get_success(
- SlidingSyncStreamToken.from_string(self.store, from_token)
- )
- parsed_from_token = SlidingSyncStreamToken(
- stream_token=parsed_from_token.stream_token,
- connection_position=connection_position,
- )
- from_token = self.get_success(parsed_from_token.to_string(self.store))
-
- # We now send another event to room1, so we should sync all the missing events.
- resp = self.helper.send(room_id1, "msg2", tok=user1_tok)
- expected_events.append(resp["event_id"])
-
- # This sync should contain the messages from room1 not yet sent down.
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id1}, response_body["rooms"]
- )
- self.assertNotIn("initial", response_body["rooms"][room_id1])
-
- self.assertEqual(
- [ev["event_id"] for ev in response_body["rooms"][room_id1]["timeline"]],
- expected_events[-timeline_limit:],
- )
- self.assertEqual(response_body["rooms"][room_id1]["limited"], limited)
- self.assertEqual(response_body["rooms"][room_id1].get("required_state"), None)
-
- def test_rooms_required_state_incremental_sync_PREVIOUSLY(self) -> None:
- """
- Test getting room data where we have previously sent down the room, but
- we missed sending down some state previously and so its status is
- considered PREVIOUSLY.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- self.helper.send(room_id1, "msg", tok=user1_tok)
-
- conn_id = "conn_id"
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 0]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.RoomHistoryVisibility, ""],
- # This one doesn't exist in the room
- [EventTypes.Name, ""],
- ],
- "timeline_limit": 0,
- }
- },
- "conn_id": "conn_id",
- }
-
- # The first room gets sent down the initial sync
- response_body, initial_from_token = self.do_sync(sync_body, tok=user1_tok)
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id1}, response_body["rooms"]
- )
-
- # We now send down some state in room1
- resp = self.helper.send_state(
- room_id1, EventTypes.Name, {"name": "foo"}, tok=user1_tok
- )
- name_change_id = resp["event_id"]
-
- # A second messages happens in the other room, so room1 won't get sent down.
- self.helper.send(room_id2, "msg", tok=user1_tok)
-
- # Only the second room gets sent down sync.
- response_body, from_token = self.do_sync(
- sync_body, since=initial_from_token, tok=user1_tok
- )
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id2}, response_body["rooms"]
- )
-
- # FIXME: This is a hack to record that the first room wasn't sent down
- # sync, as we don't implement that currently.
- sliding_sync_handler = self.hs.get_sliding_sync_handler()
- requester = self.get_success(
- self.hs.get_auth().get_user_by_access_token(user1_tok)
- )
- sync_config = SlidingSyncConfig(
- user=requester.user,
- requester=requester,
- conn_id=conn_id,
- )
-
- parsed_initial_from_token = self.get_success(
- SlidingSyncStreamToken.from_string(self.store, initial_from_token)
- )
- connection_position = self.get_success(
- sliding_sync_handler.connection_store.record_rooms(
- sync_config,
- parsed_initial_from_token,
- sent_room_ids=[],
- unsent_room_ids=[room_id1],
- )
- )
-
- # FIXME: Now fix up `from_token` with new connect position above.
- parsed_from_token = self.get_success(
- SlidingSyncStreamToken.from_string(self.store, from_token)
- )
- parsed_from_token = SlidingSyncStreamToken(
- stream_token=parsed_from_token.stream_token,
- connection_position=connection_position,
- )
- from_token = self.get_success(parsed_from_token.to_string(self.store))
-
- # We now send another event to room1, so we should sync all the missing state.
- self.helper.send(room_id1, "msg", tok=user1_tok)
-
- # This sync should contain the state changes from room1.
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id1}, response_body["rooms"]
- )
- self.assertNotIn("initial", response_body["rooms"][room_id1])
-
- # We should only see the name change.
- self.assertEqual(
- [
- ev["event_id"]
- for ev in response_body["rooms"][room_id1]["required_state"]
- ],
- [name_change_id],
- )
-
- def test_rooms_required_state_incremental_sync_NEVER(self) -> None:
- """
- Test getting `required_state` where we have NEVER sent down the room before
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- self.helper.send(room_id1, "msg", tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 0]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.RoomHistoryVisibility, ""],
- # This one doesn't exist in the room
- [EventTypes.Name, ""],
- ],
- "timeline_limit": 1,
- }
- },
- }
-
- # A message happens in the other room, so room1 won't get sent down.
- self.helper.send(room_id2, "msg", tok=user1_tok)
-
- # Only the second room gets sent down sync.
- response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id2}, response_body["rooms"]
- )
-
- # We now send another event to room1, so we should send down the full
- # room.
- self.helper.send(room_id1, "msg2", tok=user1_tok)
-
- # This sync should contain the messages from room1 not yet sent down.
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id1}, response_body["rooms"]
- )
-
- self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- state_map[(EventTypes.RoomHistoryVisibility, "")],
- },
- exact=True,
- )
-
- def test_rooms_timeline_incremental_sync_NEVER(self) -> None:
- """
- Test getting timeline room data where we have NEVER sent down the room
- before
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 0]],
- "required_state": [],
- "timeline_limit": 5,
- }
- },
- }
-
- expected_events = []
- for _ in range(4):
- resp = self.helper.send(room_id1, "msg", tok=user1_tok)
- expected_events.append(resp["event_id"])
-
- # A message happens in the other room, so room1 won't get sent down.
- self.helper.send(room_id2, "msg", tok=user1_tok)
-
- # Only the second room gets sent down sync.
- response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id2}, response_body["rooms"]
- )
-
- # We now send another event to room1 so it comes down sync
- resp = self.helper.send(room_id1, "msg2", tok=user1_tok)
- expected_events.append(resp["event_id"])
-
- # This sync should contain the messages from room1 not yet sent down.
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertCountEqual(
- response_body["rooms"].keys(), {room_id1}, response_body["rooms"]
- )
-
- self.assertEqual(
- [ev["event_id"] for ev in response_body["rooms"][room_id1]["timeline"]],
- expected_events,
- )
- self.assertEqual(response_body["rooms"][room_id1]["limited"], True)
- self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
diff --git a/tests/rest/client/sliding_sync/test_extension_account_data.py b/tests/rest/client/sliding_sync/test_extension_account_data.py
deleted file mode 100644
index 3482a5f887..0000000000
--- a/tests/rest/client/sliding_sync/test_extension_account_data.py
+++ /dev/null
@@ -1,495 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import AccountDataTypes
-from synapse.rest.client import login, room, sendtodevice, sync
-from synapse.server import HomeServer
-from synapse.types import StreamKeyType
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-from tests.server import TimedOutException
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
- """Tests for the account_data sliding sync extension"""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- sendtodevice.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.account_data_handler = hs.get_account_data_handler()
-
- def test_no_data_initial_sync(self) -> None:
- """
- Test that enabling the account_data extension works during an intitial sync,
- even if there is no-data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Make an initial Sliding Sync request with the account_data extension enabled
- sync_body = {
- "lists": {},
- "extensions": {
- "account_data": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
- # Even though we don't have any global account data set, Synapse saves some
- # default push rules for us.
- {AccountDataTypes.PUSH_RULES},
- exact=True,
- )
- self.assertIncludes(
- response_body["extensions"]["account_data"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_no_data_incremental_sync(self) -> None:
- """
- Test that enabling account_data extension works during an incremental sync, even
- if there is no-data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "account_data": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make an incremental Sliding Sync request with the account_data extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # There has been no account data changes since the `from_token` so we shouldn't
- # see any account data here.
- self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
- set(),
- exact=True,
- )
- self.assertIncludes(
- response_body["extensions"]["account_data"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_global_account_data_initial_sync(self) -> None:
- """
- On initial sync, we should return all global account data on initial sync.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Update the global account data
- self.get_success(
- self.account_data_handler.add_account_data_for_user(
- user_id=user1_id,
- account_data_type="org.matrix.foobarbaz",
- content={"foo": "bar"},
- )
- )
-
- # Make an initial Sliding Sync request with the account_data extension enabled
- sync_body = {
- "lists": {},
- "extensions": {
- "account_data": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # It should show us all of the global account data
- self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
- {AccountDataTypes.PUSH_RULES, "org.matrix.foobarbaz"},
- exact=True,
- )
- self.assertIncludes(
- response_body["extensions"]["account_data"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_global_account_data_incremental_sync(self) -> None:
- """
- On incremental sync, we should only account data that has changed since the
- `from_token`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Add some global account data
- self.get_success(
- self.account_data_handler.add_account_data_for_user(
- user_id=user1_id,
- account_data_type="org.matrix.foobarbaz",
- content={"foo": "bar"},
- )
- )
-
- sync_body = {
- "lists": {},
- "extensions": {
- "account_data": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Add some other global account data
- self.get_success(
- self.account_data_handler.add_account_data_for_user(
- user_id=user1_id,
- account_data_type="org.matrix.doodardaz",
- content={"doo": "dar"},
- )
- )
-
- # Make an incremental Sliding Sync request with the account_data extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
- # We should only see the new global account data that happened after the `from_token`
- {"org.matrix.doodardaz"},
- exact=True,
- )
- self.assertIncludes(
- response_body["extensions"]["account_data"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_room_account_data_initial_sync(self) -> None:
- """
- On initial sync, we return all account data for a given room but only for
- rooms that we request and are being returned in the Sliding Sync response.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a room and add some room account data
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.get_success(
- self.account_data_handler.add_account_data_to_room(
- user_id=user1_id,
- room_id=room_id1,
- account_data_type="org.matrix.roorarraz",
- content={"roo": "rar"},
- )
- )
-
- # Create another room with some room account data
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.get_success(
- self.account_data_handler.add_account_data_to_room(
- user_id=user1_id,
- room_id=room_id2,
- account_data_type="org.matrix.roorarraz",
- content={"roo": "rar"},
- )
- )
-
- # Make an initial Sliding Sync request with the account_data extension enabled
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- "timeline_limit": 0,
- }
- },
- "extensions": {
- "account_data": {
- "enabled": True,
- "rooms": [room_id1, room_id2],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- self.assertIsNotNone(response_body["extensions"]["account_data"].get("global"))
- # Even though we requested room2, we only expect room1 to show up because that's
- # the only room in the Sliding Sync response (room2 is not one of our room
- # subscriptions or in a sliding window list).
- self.assertIncludes(
- response_body["extensions"]["account_data"].get("rooms").keys(),
- {room_id1},
- exact=True,
- )
- self.assertIncludes(
- {
- event["type"]
- for event in response_body["extensions"]["account_data"]
- .get("rooms")
- .get(room_id1)
- },
- {"org.matrix.roorarraz"},
- exact=True,
- )
-
- def test_room_account_data_incremental_sync(self) -> None:
- """
- On incremental sync, we return all account data for a given room but only for
- rooms that we request and are being returned in the Sliding Sync response.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a room and add some room account data
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.get_success(
- self.account_data_handler.add_account_data_to_room(
- user_id=user1_id,
- room_id=room_id1,
- account_data_type="org.matrix.roorarraz",
- content={"roo": "rar"},
- )
- )
-
- # Create another room with some room account data
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.get_success(
- self.account_data_handler.add_account_data_to_room(
- user_id=user1_id,
- room_id=room_id2,
- account_data_type="org.matrix.roorarraz",
- content={"roo": "rar"},
- )
- )
-
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- "timeline_limit": 0,
- }
- },
- "extensions": {
- "account_data": {
- "enabled": True,
- "rooms": [room_id1, room_id2],
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Add some other room account data
- self.get_success(
- self.account_data_handler.add_account_data_to_room(
- user_id=user1_id,
- room_id=room_id1,
- account_data_type="org.matrix.roorarraz2",
- content={"roo": "rar"},
- )
- )
- self.get_success(
- self.account_data_handler.add_account_data_to_room(
- user_id=user1_id,
- room_id=room_id2,
- account_data_type="org.matrix.roorarraz2",
- content={"roo": "rar"},
- )
- )
-
- # Make an incremental Sliding Sync request with the account_data extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertIsNotNone(response_body["extensions"]["account_data"].get("global"))
- # Even though we requested room2, we only expect room1 to show up because that's
- # the only room in the Sliding Sync response (room2 is not one of our room
- # subscriptions or in a sliding window list).
- self.assertIncludes(
- response_body["extensions"]["account_data"].get("rooms").keys(),
- {room_id1},
- exact=True,
- )
- # We should only see the new room account data that happened after the `from_token`
- self.assertIncludes(
- {
- event["type"]
- for event in response_body["extensions"]["account_data"]
- .get("rooms")
- .get(room_id1)
- },
- {"org.matrix.roorarraz2"},
- exact=True,
- )
-
- def test_wait_for_new_data(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive.
-
- (Only applies to incremental syncs with a `timeout` specified)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {},
- "extensions": {
- "account_data": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make an incremental Sliding Sync request with the account_data extension enabled
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Bump the global account data to trigger new results
- self.get_success(
- self.account_data_handler.add_account_data_for_user(
- user1_id,
- "org.matrix.foobarbaz",
- {"foo": "bar"},
- )
- )
- # Should respond before the 10 second timeout
- channel.await_result(timeout_ms=3000)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # We should see the global account data update
- self.assertIncludes(
- {
- global_event["type"]
- for global_event in channel.json_body["extensions"]["account_data"].get(
- "global"
- )
- },
- {"org.matrix.foobarbaz"},
- exact=True,
- )
- self.assertIncludes(
- channel.json_body["extensions"]["account_data"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_wait_for_new_data_timeout(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive but
- no data ever arrives so we timeout. We're also making sure that the default data
- from the account_data extension doesn't trigger a false-positive for new data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "account_data": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Wake-up `notifier.wait_for_events(...)` that will cause us test
- # `SlidingSyncResult.__bool__` for new results.
- self._bump_notifier_wait_for_events(
- user1_id,
- # We choose `StreamKeyType.PRESENCE` because we're testing for account data
- # and don't want to contaminate the account data results using
- # `StreamKeyType.ACCOUNT_DATA`.
- wake_stream_key=StreamKeyType.PRESENCE,
- )
- # Block for a little bit more to ensure we don't see any new results.
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=4000)
- # Wait for the sync to complete (wait for the rest of the 10 second timeout,
- # 5000 + 4000 + 1200 > 10000)
- channel.await_result(timeout_ms=1200)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- self.assertIsNotNone(
- channel.json_body["extensions"]["account_data"].get("global")
- )
- self.assertIsNotNone(
- channel.json_body["extensions"]["account_data"].get("rooms")
- )
diff --git a/tests/rest/client/sliding_sync/test_extension_e2ee.py b/tests/rest/client/sliding_sync/test_extension_e2ee.py
deleted file mode 100644
index 320f8c788f..0000000000
--- a/tests/rest/client/sliding_sync/test_extension_e2ee.py
+++ /dev/null
@@ -1,441 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.rest.client import devices, login, room, sync
-from synapse.server import HomeServer
-from synapse.types import JsonDict, StreamKeyType
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-from tests.server import TimedOutException
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
- """Tests for the e2ee sliding sync extension"""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- devices.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.e2e_keys_handler = hs.get_e2e_keys_handler()
-
- def test_no_data_initial_sync(self) -> None:
- """
- Test that enabling e2ee extension works during an intitial sync, even if there
- is no-data
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Make an initial Sliding Sync request with the e2ee extension enabled
- sync_body = {
- "lists": {},
- "extensions": {
- "e2ee": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Device list updates are only present for incremental syncs
- self.assertIsNone(response_body["extensions"]["e2ee"].get("device_lists"))
-
- # Both of these should be present even when empty
- self.assertEqual(
- response_body["extensions"]["e2ee"]["device_one_time_keys_count"],
- {
- # This is always present because of
- # https://github.com/element-hq/element-android/issues/3725 and
- # https://github.com/matrix-org/synapse/issues/10456
- "signed_curve25519": 0
- },
- )
- self.assertEqual(
- response_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
- [],
- )
-
- def test_no_data_incremental_sync(self) -> None:
- """
- Test that enabling e2ee extension works during an incremental sync, even if
- there is no-data
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "e2ee": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make an incremental Sliding Sync request with the e2ee extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # Device list shows up for incremental syncs
- self.assertEqual(
- response_body["extensions"]["e2ee"].get("device_lists", {}).get("changed"),
- [],
- )
- self.assertEqual(
- response_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
- [],
- )
-
- # Both of these should be present even when empty
- self.assertEqual(
- response_body["extensions"]["e2ee"]["device_one_time_keys_count"],
- {
- # Note that "signed_curve25519" is always returned in key count responses
- # regardless of whether we uploaded any keys for it. This is necessary until
- # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
- #
- # Also related:
- # https://github.com/element-hq/element-android/issues/3725 and
- # https://github.com/matrix-org/synapse/issues/10456
- "signed_curve25519": 0
- },
- )
- self.assertEqual(
- response_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
- [],
- )
-
- def test_wait_for_new_data(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive.
-
- (Only applies to incremental syncs with a `timeout` specified)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- test_device_id = "TESTDEVICE"
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
- self.helper.join(room_id, user3_id, tok=user3_tok)
-
- sync_body = {
- "lists": {},
- "extensions": {
- "e2ee": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Bump the device lists to trigger new results
- # Have user3 update their device list
- device_update_channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=user3_tok,
- )
- self.assertEqual(
- device_update_channel.code, 200, device_update_channel.json_body
- )
- # Should respond before the 10 second timeout
- channel.await_result(timeout_ms=3000)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # We should see the device list update
- self.assertEqual(
- channel.json_body["extensions"]["e2ee"]
- .get("device_lists", {})
- .get("changed"),
- [user3_id],
- )
- self.assertEqual(
- channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
- [],
- )
-
- def test_wait_for_new_data_timeout(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive but
- no data ever arrives so we timeout. We're also making sure that the default data
- from the E2EE extension doesn't trigger a false-positive for new data (see
- `device_one_time_keys_count.signed_curve25519`).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "e2ee": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Wake-up `notifier.wait_for_events(...)` that will cause us test
- # `SlidingSyncResult.__bool__` for new results.
- self._bump_notifier_wait_for_events(
- user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA
- )
- # Block for a little bit more to ensure we don't see any new results.
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=4000)
- # Wait for the sync to complete (wait for the rest of the 10 second timeout,
- # 5000 + 4000 + 1200 > 10000)
- channel.await_result(timeout_ms=1200)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Device lists are present for incremental syncs but empty because no device changes
- self.assertEqual(
- channel.json_body["extensions"]["e2ee"]
- .get("device_lists", {})
- .get("changed"),
- [],
- )
- self.assertEqual(
- channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
- [],
- )
-
- # Both of these should be present even when empty
- self.assertEqual(
- channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
- {
- # Note that "signed_curve25519" is always returned in key count responses
- # regardless of whether we uploaded any keys for it. This is necessary until
- # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
- #
- # Also related:
- # https://github.com/element-hq/element-android/issues/3725 and
- # https://github.com/matrix-org/synapse/issues/10456
- "signed_curve25519": 0
- },
- )
- self.assertEqual(
- channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
- [],
- )
-
- def test_device_lists(self) -> None:
- """
- Test that device list updates are included in the response
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- test_device_id = "TESTDEVICE"
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
-
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
- self.helper.join(room_id, user3_id, tok=user3_tok)
- self.helper.join(room_id, user4_id, tok=user4_tok)
-
- sync_body = {
- "lists": {},
- "extensions": {
- "e2ee": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Have user3 update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=user3_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # User4 leaves the room
- self.helper.leave(room_id, user4_id, tok=user4_tok)
-
- # Make an incremental Sliding Sync request with the e2ee extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # Device list updates show up
- self.assertEqual(
- response_body["extensions"]["e2ee"].get("device_lists", {}).get("changed"),
- [user3_id],
- )
- self.assertEqual(
- response_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
- [user4_id],
- )
-
- def test_device_one_time_keys_count(self) -> None:
- """
- Test that `device_one_time_keys_count` are included in the response
- """
- test_device_id = "TESTDEVICE"
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
-
- # Upload one time keys for the user/device
- keys: JsonDict = {
- "alg1:k1": "key1",
- "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
- "alg2:k3": {"key": "key3"},
- }
- upload_keys_response = self.get_success(
- self.e2e_keys_handler.upload_keys_for_user(
- user1_id, test_device_id, {"one_time_keys": keys}
- )
- )
- self.assertDictEqual(
- upload_keys_response,
- {
- "one_time_key_counts": {
- "alg1": 1,
- "alg2": 2,
- # Note that "signed_curve25519" is always returned in key count responses
- # regardless of whether we uploaded any keys for it. This is necessary until
- # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
- #
- # Also related:
- # https://github.com/element-hq/element-android/issues/3725 and
- # https://github.com/matrix-org/synapse/issues/10456
- "signed_curve25519": 0,
- }
- },
- )
-
- # Make a Sliding Sync request with the e2ee extension enabled
- sync_body = {
- "lists": {},
- "extensions": {
- "e2ee": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Check for those one time key counts
- self.assertEqual(
- response_body["extensions"]["e2ee"].get("device_one_time_keys_count"),
- {
- "alg1": 1,
- "alg2": 2,
- # Note that "signed_curve25519" is always returned in key count responses
- # regardless of whether we uploaded any keys for it. This is necessary until
- # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
- #
- # Also related:
- # https://github.com/element-hq/element-android/issues/3725 and
- # https://github.com/matrix-org/synapse/issues/10456
- "signed_curve25519": 0,
- },
- )
-
- def test_device_unused_fallback_key_types(self) -> None:
- """
- Test that `device_unused_fallback_key_types` are included in the response
- """
- test_device_id = "TESTDEVICE"
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
-
- # We shouldn't have any unused fallback keys yet
- res = self.get_success(
- self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
- )
- self.assertEqual(res, [])
-
- # Upload a fallback key for the user/device
- self.get_success(
- self.e2e_keys_handler.upload_keys_for_user(
- user1_id,
- test_device_id,
- {"fallback_keys": {"alg1:k1": "fallback_key1"}},
- )
- )
- # We should now have an unused alg1 key
- fallback_res = self.get_success(
- self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
- )
- self.assertEqual(fallback_res, ["alg1"], fallback_res)
-
- # Make a Sliding Sync request with the e2ee extension enabled
- sync_body = {
- "lists": {},
- "extensions": {
- "e2ee": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Check for the unused fallback key types
- self.assertListEqual(
- response_body["extensions"]["e2ee"].get("device_unused_fallback_key_types"),
- ["alg1"],
- )
diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py
deleted file mode 100644
index 65fbac260e..0000000000
--- a/tests/rest/client/sliding_sync/test_extension_receipts.py
+++ /dev/null
@@ -1,679 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import EduTypes, ReceiptTypes
-from synapse.rest.client import login, receipts, room, sync
-from synapse.server import HomeServer
-from synapse.types import StreamKeyType
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-from tests.server import TimedOutException
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
- """Tests for the receipts sliding sync extension"""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- receipts.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- def test_no_data_initial_sync(self) -> None:
- """
- Test that enabling the receipts extension works during an intitial sync,
- even if there is no-data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Make an initial Sliding Sync request with the receipts extension enabled
- sync_body = {
- "lists": {},
- "extensions": {
- "receipts": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- self.assertIncludes(
- response_body["extensions"]["receipts"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_no_data_incremental_sync(self) -> None:
- """
- Test that enabling receipts extension works during an incremental sync, even
- if there is no-data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "receipts": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make an incremental Sliding Sync request with the receipts extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertIncludes(
- response_body["extensions"]["receipts"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_receipts_initial_sync_with_timeline(self) -> None:
- """
- On initial sync, we only return receipts for events in a given room's timeline.
-
- We also make sure that we only return receipts for rooms that we request and are
- already being returned in the Sliding Sync response.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
-
- # Create a room
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user3_id, tok=user3_tok)
- self.helper.join(room_id1, user4_id, tok=user4_tok)
- room1_event_response1 = self.helper.send(
- room_id1, body="new event1", tok=user2_tok
- )
- room1_event_response2 = self.helper.send(
- room_id1, body="new event2", tok=user2_tok
- )
- # User1 reads the last event
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response2['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User2 reads the last event
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response2['event_id']}",
- {},
- access_token=user2_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User3 reads the first event
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}",
- {},
- access_token=user3_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User4 privately reads the last event (make sure this doesn't leak to the other users)
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ_PRIVATE}/{room1_event_response2['event_id']}",
- {},
- access_token=user4_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Create another room
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
- self.helper.join(room_id2, user3_id, tok=user3_tok)
- self.helper.join(room_id2, user4_id, tok=user4_tok)
- room2_event_response1 = self.helper.send(
- room_id2, body="new event2", tok=user2_tok
- )
- # User1 reads the last event
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ}/{room2_event_response1['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User2 reads the last event
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ}/{room2_event_response1['event_id']}",
- {},
- access_token=user2_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User4 privately reads the last event (make sure this doesn't leak to the other users)
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ_PRIVATE}/{room2_event_response1['event_id']}",
- {},
- access_token=user4_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Make an initial Sliding Sync request with the receipts extension enabled
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- # On initial sync, we only have receipts for events in the timeline
- "timeline_limit": 1,
- }
- },
- "extensions": {
- "receipts": {
- "enabled": True,
- "rooms": [room_id1, room_id2],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Only the latest event in the room is in the timelie because the `timeline_limit` is 1
- self.assertIncludes(
- {
- event["event_id"]
- for event in response_body["rooms"][room_id1].get("timeline", [])
- },
- {room1_event_response2["event_id"]},
- exact=True,
- message=str(response_body["rooms"][room_id1]),
- )
-
- # Even though we requested room2, we only expect room1 to show up because that's
- # the only room in the Sliding Sync response (room2 is not one of our room
- # subscriptions or in a sliding window list).
- self.assertIncludes(
- response_body["extensions"]["receipts"].get("rooms").keys(),
- {room_id1},
- exact=True,
- )
- # Sanity check that it's the correct ephemeral event type
- self.assertEqual(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["type"],
- EduTypes.RECEIPT,
- )
- # We can see user1 and user2 read receipts
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][
- room1_event_response2["event_id"]
- ][ReceiptTypes.READ].keys(),
- {user1_id, user2_id},
- exact=True,
- )
- # User1 did not have a private read receipt and we shouldn't leak others'
- # private read receipts
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][
- room1_event_response2["event_id"]
- ]
- .get(ReceiptTypes.READ_PRIVATE, {})
- .keys(),
- set(),
- exact=True,
- )
-
- # We shouldn't see receipts for event2 since it wasn't in the timeline and this is an initial sync
- self.assertIsNone(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["content"].get(
- room1_event_response1["event_id"]
- )
- )
-
- def test_receipts_incremental_sync(self) -> None:
- """
- On incremental sync, we return all receipts in the token range for a given room
- but only for rooms that we request and are being returned in the Sliding Sync
- response.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
-
- # Create room1
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user3_id, tok=user3_tok)
- room1_event_response1 = self.helper.send(
- room_id1, body="new event2", tok=user2_tok
- )
- # User2 reads the last event (before the `from_token`)
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}",
- {},
- access_token=user2_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Create room2
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
- room2_event_response1 = self.helper.send(
- room_id2, body="new event2", tok=user2_tok
- )
- # User1 reads the last event (before the `from_token`)
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ}/{room2_event_response1['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Create room3
- room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id3, user1_id, tok=user1_tok)
- self.helper.join(room_id3, user3_id, tok=user3_tok)
- room3_event_response1 = self.helper.send(
- room_id3, body="new event", tok=user2_tok
- )
-
- # Create room4
- room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id4, user1_id, tok=user1_tok)
- self.helper.join(room_id4, user3_id, tok=user3_tok)
- event_response4 = self.helper.send(room_id4, body="new event", tok=user2_tok)
- # User1 reads the last event (before the `from_token`)
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id4}/receipt/{ReceiptTypes.READ}/{event_response4['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- "timeline_limit": 0,
- },
- room_id3: {
- "required_state": [],
- "timeline_limit": 0,
- },
- room_id4: {
- "required_state": [],
- "timeline_limit": 0,
- },
- },
- "extensions": {
- "receipts": {
- "enabled": True,
- "rooms": [room_id1, room_id2, room_id3, room_id4],
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Add some more read receipts after the `from_token`
- #
- # User1 reads room1
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User1 privately reads room2
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id2}/receipt/{ReceiptTypes.READ_PRIVATE}/{room2_event_response1['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User3 reads room3
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id3}/receipt/{ReceiptTypes.READ}/{room3_event_response1['event_id']}",
- {},
- access_token=user3_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # No activity for room4 after the `from_token`
-
- # Make an incremental Sliding Sync request with the receipts extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # Even though we requested room2, we only expect rooms to show up if they are
- # already in the Sliding Sync response. room4 doesn't show up because there is
- # no activity after the `from_token`.
- self.assertIncludes(
- response_body["extensions"]["receipts"].get("rooms").keys(),
- {room_id1, room_id3},
- exact=True,
- )
-
- # Check room1:
- #
- # Sanity check that it's the correct ephemeral event type
- self.assertEqual(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["type"],
- EduTypes.RECEIPT,
- )
- # We only see that user1 has read something in room1 since the `from_token`
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][
- room1_event_response1["event_id"]
- ][ReceiptTypes.READ].keys(),
- {user1_id},
- exact=True,
- )
- # User1 did not send a private read receipt in this room and we shouldn't leak
- # others' private read receipts
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][
- room1_event_response1["event_id"]
- ]
- .get(ReceiptTypes.READ_PRIVATE, {})
- .keys(),
- set(),
- exact=True,
- )
- # No events in the timeline since they were sent before the `from_token`
- self.assertNotIn(room_id1, response_body["rooms"])
-
- # Check room3:
- #
- # Sanity check that it's the correct ephemeral event type
- self.assertEqual(
- response_body["extensions"]["receipts"]["rooms"][room_id3]["type"],
- EduTypes.RECEIPT,
- )
- # We only see that user3 has read something in room1 since the `from_token`
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id3]["content"][
- room3_event_response1["event_id"]
- ][ReceiptTypes.READ].keys(),
- {user3_id},
- exact=True,
- )
- # User1 did not send a private read receipt in this room and we shouldn't leak
- # others' private read receipts
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id3]["content"][
- room3_event_response1["event_id"]
- ]
- .get(ReceiptTypes.READ_PRIVATE, {})
- .keys(),
- set(),
- exact=True,
- )
- # No events in the timeline since they were sent before the `from_token`
- self.assertNotIn(room_id3, response_body["rooms"])
-
- def test_receipts_incremental_sync_all_live_receipts(self) -> None:
- """
- On incremental sync, we return all receipts in the token range for a given room
- even if they are not in the timeline.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create room1
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- # The timeline will only include event2
- "timeline_limit": 1,
- },
- },
- "extensions": {
- "receipts": {
- "enabled": True,
- "rooms": [room_id1],
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- room1_event_response1 = self.helper.send(
- room_id1, body="new event1", tok=user2_tok
- )
- room1_event_response2 = self.helper.send(
- room_id1, body="new event2", tok=user2_tok
- )
-
- # User1 reads event1
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response1['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User2 reads event2
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id1}/receipt/{ReceiptTypes.READ}/{room1_event_response2['event_id']}",
- {},
- access_token=user2_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Make an incremental Sliding Sync request with the receipts extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # We should see room1 because it has receipts in the token range
- self.assertIncludes(
- response_body["extensions"]["receipts"].get("rooms").keys(),
- {room_id1},
- exact=True,
- )
- # Sanity check that it's the correct ephemeral event type
- self.assertEqual(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["type"],
- EduTypes.RECEIPT,
- )
- # We should see all receipts in the token range regardless of whether the events
- # are in the timeline
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][
- room1_event_response1["event_id"]
- ][ReceiptTypes.READ].keys(),
- {user1_id},
- exact=True,
- )
- self.assertIncludes(
- response_body["extensions"]["receipts"]["rooms"][room_id1]["content"][
- room1_event_response2["event_id"]
- ][ReceiptTypes.READ].keys(),
- {user2_id},
- exact=True,
- )
- # Only the latest event in the timeline because the `timeline_limit` is 1
- self.assertIncludes(
- {
- event["event_id"]
- for event in response_body["rooms"][room_id1].get("timeline", [])
- },
- {room1_event_response2["event_id"]},
- exact=True,
- message=str(response_body["rooms"][room_id1]),
- )
-
- def test_wait_for_new_data(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive.
-
- (Only applies to incremental syncs with a `timeout` specified)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
- event_response = self.helper.send(room_id, body="new event", tok=user2_tok)
-
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id: {
- "required_state": [],
- "timeline_limit": 0,
- },
- },
- "extensions": {
- "receipts": {
- "enabled": True,
- "rooms": [room_id],
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make an incremental Sliding Sync request with the receipts extension enabled
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Bump the receipts to trigger new results
- receipt_channel = self.make_request(
- "POST",
- f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_response['event_id']}",
- {},
- access_token=user2_tok,
- )
- self.assertEqual(receipt_channel.code, 200, receipt_channel.json_body)
- # Should respond before the 10 second timeout
- channel.await_result(timeout_ms=3000)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # We should see the new receipt
- self.assertIncludes(
- channel.json_body.get("extensions", {})
- .get("receipts", {})
- .get("rooms", {})
- .keys(),
- {room_id},
- exact=True,
- message=str(channel.json_body),
- )
- self.assertIncludes(
- channel.json_body["extensions"]["receipts"]["rooms"][room_id]["content"][
- event_response["event_id"]
- ][ReceiptTypes.READ].keys(),
- {user2_id},
- exact=True,
- )
- # User1 did not send a private read receipt in this room and we shouldn't leak
- # others' private read receipts
- self.assertIncludes(
- channel.json_body["extensions"]["receipts"]["rooms"][room_id]["content"][
- event_response["event_id"]
- ]
- .get(ReceiptTypes.READ_PRIVATE, {})
- .keys(),
- set(),
- exact=True,
- )
-
- def test_wait_for_new_data_timeout(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive but
- no data ever arrives so we timeout. We're also making sure that the default data
- from the receipts extension doesn't trigger a false-positive for new data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "receipts": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Wake-up `notifier.wait_for_events(...)` that will cause us test
- # `SlidingSyncResult.__bool__` for new results.
- self._bump_notifier_wait_for_events(
- user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA
- )
- # Block for a little bit more to ensure we don't see any new results.
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=4000)
- # Wait for the sync to complete (wait for the rest of the 10 second timeout,
- # 5000 + 4000 + 1200 > 10000)
- channel.await_result(timeout_ms=1200)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- self.assertIncludes(
- channel.json_body["extensions"]["receipts"].get("rooms").keys(),
- set(),
- exact=True,
- )
diff --git a/tests/rest/client/sliding_sync/test_extension_to_device.py b/tests/rest/client/sliding_sync/test_extension_to_device.py
deleted file mode 100644
index f8500812ea..0000000000
--- a/tests/rest/client/sliding_sync/test_extension_to_device.py
+++ /dev/null
@@ -1,278 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-from typing import List
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.rest.client import login, sendtodevice, sync
-from synapse.server import HomeServer
-from synapse.types import JsonDict, StreamKeyType
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-from tests.server import TimedOutException
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase):
- """Tests for the to-device sliding sync extension"""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- sync.register_servlets,
- sendtodevice.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- def _assert_to_device_response(
- self, response_body: JsonDict, expected_messages: List[JsonDict]
- ) -> str:
- """Assert the sliding sync response was successful and has the expected
- to-device messages.
-
- Returns the next_batch token from the to-device section.
- """
- extensions = response_body["extensions"]
- to_device = extensions["to_device"]
- self.assertIsInstance(to_device["next_batch"], str)
- self.assertEqual(to_device["events"], expected_messages)
-
- return to_device["next_batch"]
-
- def test_no_data(self) -> None:
- """Test that enabling to-device extension works, even if there is
- no-data
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # We expect no to-device messages
- self._assert_to_device_response(response_body, [])
-
- def test_data_initial_sync(self) -> None:
- """Test that we get to-device messages when we don't specify a since
- token"""
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass", "d1")
- user2_id = self.register_user("u2", "pass")
- user2_tok = self.login(user2_id, "pass", "d2")
-
- # Send the to-device message
- test_msg = {"foo": "bar"}
- chan = self.make_request(
- "PUT",
- "/_matrix/client/r0/sendToDevice/m.test/1234",
- content={"messages": {user1_id: {"d1": test_msg}}},
- access_token=user2_tok,
- )
- self.assertEqual(chan.code, 200, chan.result)
-
- sync_body = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- self._assert_to_device_response(
- response_body,
- [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
- )
-
- def test_data_incremental_sync(self) -> None:
- """Test that we get to-device messages over incremental syncs"""
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass", "d1")
- user2_id = self.register_user("u2", "pass")
- user2_tok = self.login(user2_id, "pass", "d2")
-
- sync_body: JsonDict = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- # No to-device messages yet.
- next_batch = self._assert_to_device_response(response_body, [])
-
- test_msg = {"foo": "bar"}
- chan = self.make_request(
- "PUT",
- "/_matrix/client/r0/sendToDevice/m.test/1234",
- content={"messages": {user1_id: {"d1": test_msg}}},
- access_token=user2_tok,
- )
- self.assertEqual(chan.code, 200, chan.result)
-
- sync_body = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- "since": next_batch,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- next_batch = self._assert_to_device_response(
- response_body,
- [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
- )
-
- # The next sliding sync request should not include the to-device
- # message.
- sync_body = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- "since": next_batch,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- self._assert_to_device_response(response_body, [])
-
- # An initial sliding sync request should not include the to-device
- # message, as it should have been deleted
- sync_body = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- self._assert_to_device_response(response_body, [])
-
- def test_wait_for_new_data(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive.
-
- (Only applies to incremental syncs with a `timeout` specified)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass", "d1")
- user2_id = self.register_user("u2", "pass")
- user2_tok = self.login(user2_id, "pass", "d2")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Bump the to-device messages to trigger new results
- test_msg = {"foo": "bar"}
- send_to_device_channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/sendToDevice/m.test/1234",
- content={"messages": {user1_id: {"d1": test_msg}}},
- access_token=user2_tok,
- )
- self.assertEqual(
- send_to_device_channel.code, 200, send_to_device_channel.result
- )
- # Should respond before the 10 second timeout
- channel.await_result(timeout_ms=3000)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- self._assert_to_device_response(
- channel.json_body,
- [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
- )
-
- def test_wait_for_new_data_timeout(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive but
- no data ever arrives so we timeout. We're also making sure that the default data
- from the To-Device extension doesn't trigger a false-positive for new data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "to_device": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + "?timeout=10000" + f"&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Wake-up `notifier.wait_for_events(...)` that will cause us test
- # `SlidingSyncResult.__bool__` for new results.
- self._bump_notifier_wait_for_events(
- user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA
- )
- # Block for a little bit more to ensure we don't see any new results.
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=4000)
- # Wait for the sync to complete (wait for the rest of the 10 second timeout,
- # 5000 + 4000 + 1200 > 10000)
- channel.await_result(timeout_ms=1200)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- self._assert_to_device_response(channel.json_body, [])
diff --git a/tests/rest/client/sliding_sync/test_extension_typing.py b/tests/rest/client/sliding_sync/test_extension_typing.py
deleted file mode 100644
index 7f523e0f10..0000000000
--- a/tests/rest/client/sliding_sync/test_extension_typing.py
+++ /dev/null
@@ -1,482 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import EduTypes
-from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
-from synapse.types import StreamKeyType
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-from tests.server import TimedOutException
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncTypingExtensionTestCase(SlidingSyncBase):
- """Tests for the typing notification sliding sync extension"""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- def test_no_data_initial_sync(self) -> None:
- """
- Test that enabling the typing extension works during an intitial sync,
- even if there is no-data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Make an initial Sliding Sync request with the typing extension enabled
- sync_body = {
- "lists": {},
- "extensions": {
- "typing": {
- "enabled": True,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- self.assertIncludes(
- response_body["extensions"]["typing"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_no_data_incremental_sync(self) -> None:
- """
- Test that enabling typing extension works during an incremental sync, even
- if there is no-data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "typing": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make an incremental Sliding Sync request with the typing extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- self.assertIncludes(
- response_body["extensions"]["typing"].get("rooms").keys(),
- set(),
- exact=True,
- )
-
- def test_typing_initial_sync(self) -> None:
- """
- On initial sync, we return all typing notifications for rooms that we request
- and are being returned in the Sliding Sync response.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
-
- # Create a room
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user3_id, tok=user3_tok)
- self.helper.join(room_id1, user4_id, tok=user4_tok)
- # User1 starts typing in room1
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id1}/typing/{user1_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User2 starts typing in room1
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id1}/typing/{user2_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user2_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Create another room
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
- self.helper.join(room_id2, user3_id, tok=user3_tok)
- self.helper.join(room_id2, user4_id, tok=user4_tok)
- # User1 starts typing in room2
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id2}/typing/{user1_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User2 starts typing in room2
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id2}/typing/{user2_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user2_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Make an initial Sliding Sync request with the typing extension enabled
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- "timeline_limit": 0,
- }
- },
- "extensions": {
- "typing": {
- "enabled": True,
- "rooms": [room_id1, room_id2],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Even though we requested room2, we only expect room1 to show up because that's
- # the only room in the Sliding Sync response (room2 is not one of our room
- # subscriptions or in a sliding window list).
- self.assertIncludes(
- response_body["extensions"]["typing"].get("rooms").keys(),
- {room_id1},
- exact=True,
- )
- # Sanity check that it's the correct ephemeral event type
- self.assertEqual(
- response_body["extensions"]["typing"]["rooms"][room_id1]["type"],
- EduTypes.TYPING,
- )
- # We can see user1 and user2 typing
- self.assertIncludes(
- set(
- response_body["extensions"]["typing"]["rooms"][room_id1]["content"][
- "user_ids"
- ]
- ),
- {user1_id, user2_id},
- exact=True,
- )
-
- def test_typing_incremental_sync(self) -> None:
- """
- On incremental sync, we return all typing notifications in the token range for a
- given room but only for rooms that we request and are being returned in the
- Sliding Sync response.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
-
- # Create room1
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user3_id, tok=user3_tok)
- # User2 starts typing in room1
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id1}/typing/{user2_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user2_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Create room2
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
- # User1 starts typing in room2 (before the `from_token`)
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id2}/typing/{user1_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Create room3
- room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id3, user1_id, tok=user1_tok)
- self.helper.join(room_id3, user3_id, tok=user3_tok)
-
- # Create room4
- room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id4, user1_id, tok=user1_tok)
- self.helper.join(room_id4, user3_id, tok=user3_tok)
- # User1 starts typing in room4 (before the `from_token`)
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id4}/typing/{user1_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Advance time so all of the typing notifications timeout before we make our
- # Sliding Sync requests. Even though these are sent before the `from_token`, the
- # typing code only keeps track of stream position of the latest typing
- # notification so "old" typing notifications that are still "alive" (haven't
- # timed out) can appear in the response.
- self.reactor.advance(36)
-
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- "timeline_limit": 0,
- },
- room_id3: {
- "required_state": [],
- "timeline_limit": 0,
- },
- room_id4: {
- "required_state": [],
- "timeline_limit": 0,
- },
- },
- "extensions": {
- "typing": {
- "enabled": True,
- "rooms": [room_id1, room_id2, room_id3, room_id4],
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Add some more typing notifications after the `from_token`
- #
- # User1 starts typing in room1
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id1}/typing/{user1_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User1 starts typing in room2
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id2}/typing/{user1_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # User3 starts typing in room3
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id3}/typing/{user3_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user3_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- # No activity for room4 after the `from_token`
-
- # Make an incremental Sliding Sync request with the typing extension enabled
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # Even though we requested room2, we only expect rooms to show up if they are
- # already in the Sliding Sync response. room4 doesn't show up because there is
- # no activity after the `from_token`.
- self.assertIncludes(
- response_body["extensions"]["typing"].get("rooms").keys(),
- {room_id1, room_id3},
- exact=True,
- )
-
- # Check room1:
- #
- # Sanity check that it's the correct ephemeral event type
- self.assertEqual(
- response_body["extensions"]["typing"]["rooms"][room_id1]["type"],
- EduTypes.TYPING,
- )
- # We only see that user1 is typing in room1 since the `from_token`
- self.assertIncludes(
- set(
- response_body["extensions"]["typing"]["rooms"][room_id1]["content"][
- "user_ids"
- ]
- ),
- {user1_id},
- exact=True,
- )
-
- # Check room3:
- #
- # Sanity check that it's the correct ephemeral event type
- self.assertEqual(
- response_body["extensions"]["typing"]["rooms"][room_id3]["type"],
- EduTypes.TYPING,
- )
- # We only see that user3 is typing in room1 since the `from_token`
- self.assertIncludes(
- set(
- response_body["extensions"]["typing"]["rooms"][room_id3]["content"][
- "user_ids"
- ]
- ),
- {user3_id},
- exact=True,
- )
-
- def test_wait_for_new_data(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive.
-
- (Only applies to incremental syncs with a `timeout` specified)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {},
- "room_subscriptions": {
- room_id: {
- "required_state": [],
- "timeline_limit": 0,
- },
- },
- "extensions": {
- "typing": {
- "enabled": True,
- "rooms": [room_id],
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make an incremental Sliding Sync request with the typing extension enabled
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Bump the typing status to trigger new results
- typing_channel = self.make_request(
- "PUT",
- f"/rooms/{room_id}/typing/{user2_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user2_tok,
- )
- self.assertEqual(typing_channel.code, 200, typing_channel.json_body)
- # Should respond before the 10 second timeout
- channel.await_result(timeout_ms=3000)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # We should see the new typing notification
- self.assertIncludes(
- channel.json_body.get("extensions", {})
- .get("typing", {})
- .get("rooms", {})
- .keys(),
- {room_id},
- exact=True,
- message=str(channel.json_body),
- )
- self.assertIncludes(
- set(
- channel.json_body["extensions"]["typing"]["rooms"][room_id]["content"][
- "user_ids"
- ]
- ),
- {user2_id},
- exact=True,
- )
-
- def test_wait_for_new_data_timeout(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive but
- no data ever arrives so we timeout. We're also making sure that the default data
- from the typing extension doesn't trigger a false-positive for new data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- sync_body = {
- "lists": {},
- "extensions": {
- "typing": {
- "enabled": True,
- }
- },
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Wake-up `notifier.wait_for_events(...)` that will cause us test
- # `SlidingSyncResult.__bool__` for new results.
- self._bump_notifier_wait_for_events(
- user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA
- )
- # Block for a little bit more to ensure we don't see any new results.
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=4000)
- # Wait for the sync to complete (wait for the rest of the 10 second timeout,
- # 5000 + 4000 + 1200 > 10000)
- channel.await_result(timeout_ms=1200)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- self.assertIncludes(
- channel.json_body["extensions"]["typing"].get("rooms").keys(),
- set(),
- exact=True,
- )
diff --git a/tests/rest/client/sliding_sync/test_extensions.py b/tests/rest/client/sliding_sync/test_extensions.py
deleted file mode 100644
index 68f6661334..0000000000
--- a/tests/rest/client/sliding_sync/test_extensions.py
+++ /dev/null
@@ -1,283 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-from typing import Literal
-
-from parameterized import parameterized
-from typing_extensions import assert_never
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import ReceiptTypes
-from synapse.rest.client import login, receipts, room, sync
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncExtensionsTestCase(SlidingSyncBase):
- """
- Test general extensions behavior in the Sliding Sync API. Each extension has their
- own suite of tests in their own file as well.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- receipts.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
- self.account_data_handler = hs.get_account_data_handler()
-
- # Any extensions that use `lists`/`rooms` should be tested here
- @parameterized.expand([("account_data",), ("receipts",), ("typing",)])
- def test_extensions_lists_rooms_relevant_rooms(
- self,
- extension_name: Literal["account_data", "receipts", "typing"],
- ) -> None:
- """
- With various extensions, test out requesting different variations of
- `lists`/`rooms`.
-
- Stresses `SlidingSyncHandler.find_relevant_room_ids_for_extension(...)`
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create some rooms
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id3 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id4 = self.helper.create_room_as(user1_id, tok=user1_tok)
- room_id5 = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- room_id_to_human_name_map = {
- room_id1: "room1",
- room_id2: "room2",
- room_id3: "room3",
- room_id4: "room4",
- room_id5: "room5",
- }
-
- for room_id in room_id_to_human_name_map.keys():
- if extension_name == "account_data":
- # Add some account data to each room
- self.get_success(
- self.account_data_handler.add_account_data_to_room(
- user_id=user1_id,
- room_id=room_id,
- account_data_type="org.matrix.roorarraz",
- content={"roo": "rar"},
- )
- )
- elif extension_name == "receipts":
- event_response = self.helper.send(
- room_id, body="new event", tok=user1_tok
- )
- # Read last event
- channel = self.make_request(
- "POST",
- f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_response['event_id']}",
- {},
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- elif extension_name == "typing":
- # Start a typing notification
- channel = self.make_request(
- "PUT",
- f"/rooms/{room_id}/typing/{user1_id}",
- b'{"typing": true, "timeout": 30000}',
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- else:
- assert_never(extension_name)
-
- main_sync_body = {
- "lists": {
- # We expect this list range to include room5 and room4
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- },
- # We expect this list range to include room5, room4, room3
- "bar-list": {
- "ranges": [[0, 2]],
- "required_state": [],
- "timeline_limit": 0,
- },
- },
- "room_subscriptions": {
- room_id1: {
- "required_state": [],
- "timeline_limit": 0,
- }
- },
- }
-
- # Mix lists and rooms
- sync_body = {
- **main_sync_body,
- "extensions": {
- extension_name: {
- "enabled": True,
- "lists": ["foo-list", "non-existent-list"],
- "rooms": [room_id1, room_id2, "!non-existent-room"],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # room1: ✅ Requested via `rooms` and a room subscription exists
- # room2: ❌ Requested via `rooms` but not in the response (from lists or room subscriptions)
- # room3: ❌ Not requested
- # room4: ✅ Shows up because requested via `lists` and list exists in the response
- # room5: ✅ Shows up because requested via `lists` and list exists in the response
- self.assertIncludes(
- {
- room_id_to_human_name_map[room_id]
- for room_id in response_body["extensions"][extension_name]
- .get("rooms")
- .keys()
- },
- {"room1", "room4", "room5"},
- exact=True,
- )
-
- # Try wildcards (this is the default)
- sync_body = {
- **main_sync_body,
- "extensions": {
- extension_name: {
- "enabled": True,
- # "lists": ["*"],
- # "rooms": ["*"],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # room1: ✅ Shows up because of default `rooms` wildcard and is in one of the room subscriptions
- # room2: ❌ Not requested
- # room3: ✅ Shows up because of default `lists` wildcard and is in a list
- # room4: ✅ Shows up because of default `lists` wildcard and is in a list
- # room5: ✅ Shows up because of default `lists` wildcard and is in a list
- self.assertIncludes(
- {
- room_id_to_human_name_map[room_id]
- for room_id in response_body["extensions"][extension_name]
- .get("rooms")
- .keys()
- },
- {"room1", "room3", "room4", "room5"},
- exact=True,
- )
-
- # Empty list will return nothing
- sync_body = {
- **main_sync_body,
- "extensions": {
- extension_name: {
- "enabled": True,
- "lists": [],
- "rooms": [],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # room1: ❌ Not requested
- # room2: ❌ Not requested
- # room3: ❌ Not requested
- # room4: ❌ Not requested
- # room5: ❌ Not requested
- self.assertIncludes(
- {
- room_id_to_human_name_map[room_id]
- for room_id in response_body["extensions"][extension_name]
- .get("rooms")
- .keys()
- },
- set(),
- exact=True,
- )
-
- # Try wildcard and none
- sync_body = {
- **main_sync_body,
- "extensions": {
- extension_name: {
- "enabled": True,
- "lists": ["*"],
- "rooms": [],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # room1: ❌ Not requested
- # room2: ❌ Not requested
- # room3: ✅ Shows up because of default `lists` wildcard and is in a list
- # room4: ✅ Shows up because of default `lists` wildcard and is in a list
- # room5: ✅ Shows up because of default `lists` wildcard and is in a list
- self.assertIncludes(
- {
- room_id_to_human_name_map[room_id]
- for room_id in response_body["extensions"][extension_name]
- .get("rooms")
- .keys()
- },
- {"room3", "room4", "room5"},
- exact=True,
- )
-
- # Try requesting a room that is only in a list
- sync_body = {
- **main_sync_body,
- "extensions": {
- extension_name: {
- "enabled": True,
- "lists": [],
- "rooms": [room_id5],
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # room1: ❌ Not requested
- # room2: ❌ Not requested
- # room3: ❌ Not requested
- # room4: ❌ Not requested
- # room5: ✅ Requested via `rooms` and is in a list
- self.assertIncludes(
- {
- room_id_to_human_name_map[room_id]
- for room_id in response_body["extensions"][extension_name]
- .get("rooms")
- .keys()
- },
- {"room5"},
- exact=True,
- )
diff --git a/tests/rest/client/sliding_sync/test_room_subscriptions.py b/tests/rest/client/sliding_sync/test_room_subscriptions.py
deleted file mode 100644
index cc17b0b354..0000000000
--- a/tests/rest/client/sliding_sync/test_room_subscriptions.py
+++ /dev/null
@@ -1,285 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-from http import HTTPStatus
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import EventTypes, HistoryVisibility
-from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncRoomSubscriptionsTestCase(SlidingSyncBase):
- """
- Test `room_subscriptions` in the Sliding Sync API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
-
- def test_room_subscriptions_with_join_membership(self) -> None:
- """
- Test `room_subscriptions` with a joined room should give us timeline and current
- state events.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request with just the room subscription
- sync_body = {
- "room_subscriptions": {
- room_id1: {
- "required_state": [
- [EventTypes.Create, ""],
- ],
- "timeline_limit": 1,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- # We should see some state
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- # We should see some events
- self.assertEqual(
- [
- event["event_id"]
- for event in response_body["rooms"][room_id1]["timeline"]
- ],
- [
- join_response["event_id"],
- ],
- response_body["rooms"][room_id1]["timeline"],
- )
- # No "live" events in an initial sync (no `from_token` to define the "live"
- # range)
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 0,
- response_body["rooms"][room_id1],
- )
- # There are more events to paginate to
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- True,
- response_body["rooms"][room_id1],
- )
-
- def test_room_subscriptions_with_leave_membership(self) -> None:
- """
- Test `room_subscriptions` with a leave room should give us timeline and state
- events up to the leave event.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
-
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- # Send some events after user1 leaves
- self.helper.send(room_id1, "activity after leave", tok=user2_tok)
- # Update state after user1 leaves
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="",
- body={"foo": "qux"},
- tok=user2_tok,
- )
-
- # Make the Sliding Sync request with just the room subscription
- sync_body = {
- "room_subscriptions": {
- room_id1: {
- "required_state": [
- ["org.matrix.foo_state", ""],
- ],
- "timeline_limit": 2,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # We should see the state at the time of the leave
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[("org.matrix.foo_state", "")],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- # We should see some before we left (nothing after)
- self.assertEqual(
- [
- event["event_id"]
- for event in response_body["rooms"][room_id1]["timeline"]
- ],
- [
- join_response["event_id"],
- leave_response["event_id"],
- ],
- response_body["rooms"][room_id1]["timeline"],
- )
- # No "live" events in an initial sync (no `from_token` to define the "live"
- # range)
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 0,
- response_body["rooms"][room_id1],
- )
- # There are more events to paginate to
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- True,
- response_body["rooms"][room_id1],
- )
-
- def test_room_subscriptions_no_leak_private_room(self) -> None:
- """
- Test `room_subscriptions` with a private room we have never been in should not
- leak any data to the user.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=False)
-
- # We should not be able to join the private room
- self.helper.join(
- room_id1, user1_id, tok=user1_tok, expect_code=HTTPStatus.FORBIDDEN
- )
-
- # Make the Sliding Sync request with just the room subscription
- sync_body = {
- "room_subscriptions": {
- room_id1: {
- "required_state": [
- [EventTypes.Create, ""],
- ],
- "timeline_limit": 1,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # We should not see the room at all (we're not in it)
- self.assertIsNone(response_body["rooms"].get(room_id1), response_body["rooms"])
-
- def test_room_subscriptions_world_readable(self) -> None:
- """
- Test `room_subscriptions` with a room that has `world_readable` history visibility
-
- FIXME: We should be able to see the room timeline and state
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a room with `world_readable` history visibility
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "preset": "public_chat",
- "initial_state": [
- {
- "content": {
- "history_visibility": HistoryVisibility.WORLD_READABLE
- },
- "state_key": "",
- "type": EventTypes.RoomHistoryVisibility,
- }
- ],
- },
- )
- # Ensure we're testing with a room with `world_readable` history visibility
- # which means events are visible to anyone even without membership.
- history_visibility_response = self.helper.get_state(
- room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
- )
- self.assertEqual(
- history_visibility_response.get("history_visibility"),
- HistoryVisibility.WORLD_READABLE,
- )
-
- # Note: We never join the room
-
- # Make the Sliding Sync request with just the room subscription
- sync_body = {
- "room_subscriptions": {
- room_id1: {
- "required_state": [
- [EventTypes.Create, ""],
- ],
- "timeline_limit": 1,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # FIXME: In the future, we should be able to see the room because it's
- # `world_readable` but currently we don't support this.
- self.assertIsNone(response_body["rooms"].get(room_id1), response_body["rooms"])
diff --git a/tests/rest/client/sliding_sync/test_rooms_invites.py b/tests/rest/client/sliding_sync/test_rooms_invites.py
deleted file mode 100644
index f08ffaf674..0000000000
--- a/tests/rest/client/sliding_sync/test_rooms_invites.py
+++ /dev/null
@@ -1,510 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import EventTypes, HistoryVisibility
-from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
-from synapse.types import UserID
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase):
- """
- Test to make sure the `rooms` response looks good for invites in the Sliding Sync API.
-
- Invites behave a lot different than other rooms because we don't include the
- `timeline` (`num_live`, `limited`, `prev_batch`) or `required_state` in favor of
- some stripped state under the `invite_state` key.
-
- Knocks probably have the same behavior but the spec doesn't mention knocks yet.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
-
- def test_rooms_invite_shared_history_initial_sync(self) -> None:
- """
- Test that `rooms` we are invited to have some stripped `invite_state` during an
- initial sync.
-
- This is an `invite` room so we should only have `stripped_state` (no `timeline`)
- but we also shouldn't see any timeline events because the history visiblity is
- `shared` and we haven't joined the room yet.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user1 = UserID.from_string(user1_id)
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user2 = UserID.from_string(user2_id)
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- # Ensure we're testing with a room with `shared` history visibility which means
- # history visible until you actually join the room.
- history_visibility_response = self.helper.get_state(
- room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
- )
- self.assertEqual(
- history_visibility_response.get("history_visibility"),
- HistoryVisibility.SHARED,
- )
-
- self.helper.send(room_id1, "activity before1", tok=user2_tok)
- self.helper.send(room_id1, "activity before2", tok=user2_tok)
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- self.helper.send(room_id1, "activity after3", tok=user2_tok)
- self.helper.send(room_id1, "activity after4", tok=user2_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 3,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # `timeline` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("timeline"),
- response_body["rooms"][room_id1],
- )
- # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("num_live"),
- response_body["rooms"][room_id1],
- )
- # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("limited"),
- response_body["rooms"][room_id1],
- )
- # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("prev_batch"),
- response_body["rooms"][room_id1],
- )
- # `required_state` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("required_state"),
- response_body["rooms"][room_id1],
- )
- # We should have some `stripped_state` so the potential joiner can identify the
- # room (we don't care about the order).
- self.assertCountEqual(
- response_body["rooms"][room_id1]["invite_state"],
- [
- {
- "content": {"creator": user2_id, "room_version": "10"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.create",
- },
- {
- "content": {"join_rule": "public"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.join_rules",
- },
- {
- "content": {"displayname": user2.localpart, "membership": "join"},
- "sender": user2_id,
- "state_key": user2_id,
- "type": "m.room.member",
- },
- {
- "content": {"displayname": user1.localpart, "membership": "invite"},
- "sender": user2_id,
- "state_key": user1_id,
- "type": "m.room.member",
- },
- ],
- response_body["rooms"][room_id1]["invite_state"],
- )
-
- def test_rooms_invite_shared_history_incremental_sync(self) -> None:
- """
- Test that `rooms` we are invited to have some stripped `invite_state` during an
- incremental sync.
-
- This is an `invite` room so we should only have `stripped_state` (no `timeline`)
- but we also shouldn't see any timeline events because the history visiblity is
- `shared` and we haven't joined the room yet.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user1 = UserID.from_string(user1_id)
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user2 = UserID.from_string(user2_id)
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- # Ensure we're testing with a room with `shared` history visibility which means
- # history visible until you actually join the room.
- history_visibility_response = self.helper.get_state(
- room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
- )
- self.assertEqual(
- history_visibility_response.get("history_visibility"),
- HistoryVisibility.SHARED,
- )
-
- self.helper.send(room_id1, "activity before invite1", tok=user2_tok)
- self.helper.send(room_id1, "activity before invite2", tok=user2_tok)
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- self.helper.send(room_id1, "activity after invite3", tok=user2_tok)
- self.helper.send(room_id1, "activity after invite4", tok=user2_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 3,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- self.helper.send(room_id1, "activity after token5", tok=user2_tok)
- self.helper.send(room_id1, "activity after toekn6", tok=user2_tok)
-
- # Make the Sliding Sync request
- response_body, from_token = self.do_sync(
- sync_body, since=from_token, tok=user1_tok
- )
-
- # `timeline` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("timeline"),
- response_body["rooms"][room_id1],
- )
- # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("num_live"),
- response_body["rooms"][room_id1],
- )
- # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("limited"),
- response_body["rooms"][room_id1],
- )
- # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("prev_batch"),
- response_body["rooms"][room_id1],
- )
- # `required_state` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("required_state"),
- response_body["rooms"][room_id1],
- )
- # We should have some `stripped_state` so the potential joiner can identify the
- # room (we don't care about the order).
- self.assertCountEqual(
- response_body["rooms"][room_id1]["invite_state"],
- [
- {
- "content": {"creator": user2_id, "room_version": "10"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.create",
- },
- {
- "content": {"join_rule": "public"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.join_rules",
- },
- {
- "content": {"displayname": user2.localpart, "membership": "join"},
- "sender": user2_id,
- "state_key": user2_id,
- "type": "m.room.member",
- },
- {
- "content": {"displayname": user1.localpart, "membership": "invite"},
- "sender": user2_id,
- "state_key": user1_id,
- "type": "m.room.member",
- },
- ],
- response_body["rooms"][room_id1]["invite_state"],
- )
-
- def test_rooms_invite_world_readable_history_initial_sync(self) -> None:
- """
- Test that `rooms` we are invited to have some stripped `invite_state` during an
- initial sync.
-
- This is an `invite` room so we should only have `stripped_state` (no `timeline`)
- but depending on the semantics we decide, we could potentially see some
- historical events before/after the `from_token` because the history is
- `world_readable`. Same situation for events after the `from_token` if the
- history visibility was set to `invited`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user1 = UserID.from_string(user1_id)
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user2 = UserID.from_string(user2_id)
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "preset": "public_chat",
- "initial_state": [
- {
- "content": {
- "history_visibility": HistoryVisibility.WORLD_READABLE
- },
- "state_key": "",
- "type": EventTypes.RoomHistoryVisibility,
- }
- ],
- },
- )
- # Ensure we're testing with a room with `world_readable` history visibility
- # which means events are visible to anyone even without membership.
- history_visibility_response = self.helper.get_state(
- room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
- )
- self.assertEqual(
- history_visibility_response.get("history_visibility"),
- HistoryVisibility.WORLD_READABLE,
- )
-
- self.helper.send(room_id1, "activity before1", tok=user2_tok)
- self.helper.send(room_id1, "activity before2", tok=user2_tok)
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- self.helper.send(room_id1, "activity after3", tok=user2_tok)
- self.helper.send(room_id1, "activity after4", tok=user2_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- # Large enough to see the latest events and before the invite
- "timeline_limit": 4,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # `timeline` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("timeline"),
- response_body["rooms"][room_id1],
- )
- # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("num_live"),
- response_body["rooms"][room_id1],
- )
- # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("limited"),
- response_body["rooms"][room_id1],
- )
- # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("prev_batch"),
- response_body["rooms"][room_id1],
- )
- # `required_state` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("required_state"),
- response_body["rooms"][room_id1],
- )
- # We should have some `stripped_state` so the potential joiner can identify the
- # room (we don't care about the order).
- self.assertCountEqual(
- response_body["rooms"][room_id1]["invite_state"],
- [
- {
- "content": {"creator": user2_id, "room_version": "10"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.create",
- },
- {
- "content": {"join_rule": "public"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.join_rules",
- },
- {
- "content": {"displayname": user2.localpart, "membership": "join"},
- "sender": user2_id,
- "state_key": user2_id,
- "type": "m.room.member",
- },
- {
- "content": {"displayname": user1.localpart, "membership": "invite"},
- "sender": user2_id,
- "state_key": user1_id,
- "type": "m.room.member",
- },
- ],
- response_body["rooms"][room_id1]["invite_state"],
- )
-
- def test_rooms_invite_world_readable_history_incremental_sync(self) -> None:
- """
- Test that `rooms` we are invited to have some stripped `invite_state` during an
- incremental sync.
-
- This is an `invite` room so we should only have `stripped_state` (no `timeline`)
- but depending on the semantics we decide, we could potentially see some
- historical events before/after the `from_token` because the history is
- `world_readable`. Same situation for events after the `from_token` if the
- history visibility was set to `invited`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user1 = UserID.from_string(user1_id)
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user2 = UserID.from_string(user2_id)
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "preset": "public_chat",
- "initial_state": [
- {
- "content": {
- "history_visibility": HistoryVisibility.WORLD_READABLE
- },
- "state_key": "",
- "type": EventTypes.RoomHistoryVisibility,
- }
- ],
- },
- )
- # Ensure we're testing with a room with `world_readable` history visibility
- # which means events are visible to anyone even without membership.
- history_visibility_response = self.helper.get_state(
- room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok
- )
- self.assertEqual(
- history_visibility_response.get("history_visibility"),
- HistoryVisibility.WORLD_READABLE,
- )
-
- self.helper.send(room_id1, "activity before invite1", tok=user2_tok)
- self.helper.send(room_id1, "activity before invite2", tok=user2_tok)
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- self.helper.send(room_id1, "activity after invite3", tok=user2_tok)
- self.helper.send(room_id1, "activity after invite4", tok=user2_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- # Large enough to see the latest events and before the invite
- "timeline_limit": 4,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- self.helper.send(room_id1, "activity after token5", tok=user2_tok)
- self.helper.send(room_id1, "activity after toekn6", tok=user2_tok)
-
- # Make the incremental Sliding Sync request
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # `timeline` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("timeline"),
- response_body["rooms"][room_id1],
- )
- # `num_live` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("num_live"),
- response_body["rooms"][room_id1],
- )
- # `limited` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("limited"),
- response_body["rooms"][room_id1],
- )
- # `prev_batch` is omitted for `invite` rooms with `stripped_state` (no timeline anyway)
- self.assertIsNone(
- response_body["rooms"][room_id1].get("prev_batch"),
- response_body["rooms"][room_id1],
- )
- # `required_state` is omitted for `invite` rooms with `stripped_state`
- self.assertIsNone(
- response_body["rooms"][room_id1].get("required_state"),
- response_body["rooms"][room_id1],
- )
- # We should have some `stripped_state` so the potential joiner can identify the
- # room (we don't care about the order).
- self.assertCountEqual(
- response_body["rooms"][room_id1]["invite_state"],
- [
- {
- "content": {"creator": user2_id, "room_version": "10"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.create",
- },
- {
- "content": {"join_rule": "public"},
- "sender": user2_id,
- "state_key": "",
- "type": "m.room.join_rules",
- },
- {
- "content": {"displayname": user2.localpart, "membership": "join"},
- "sender": user2_id,
- "state_key": user2_id,
- "type": "m.room.member",
- },
- {
- "content": {"displayname": user1.localpart, "membership": "invite"},
- "sender": user2_id,
- "state_key": user1_id,
- "type": "m.room.member",
- },
- ],
- response_body["rooms"][room_id1]["invite_state"],
- )
diff --git a/tests/rest/client/sliding_sync/test_rooms_meta.py b/tests/rest/client/sliding_sync/test_rooms_meta.py
deleted file mode 100644
index 04f11c0524..0000000000
--- a/tests/rest/client/sliding_sync/test_rooms_meta.py
+++ /dev/null
@@ -1,710 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import RoomVersions
-from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-from tests.test_utils.event_injection import create_event
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
- """
- Test rooms meta info like name, avatar, joined_count, invited_count, is_dm,
- bump_stamp in the Sliding Sync API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
-
- def test_rooms_meta_when_joined(self) -> None:
- """
- Test that the `rooms` `name` and `avatar` are included in the response and
- reflect the current state of the room when the user is joined to the room.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "name": "my super room",
- },
- )
- # Set the room avatar URL
- self.helper.send_state(
- room_id1,
- EventTypes.RoomAvatar,
- {"url": "mxc://DUMMY_MEDIA_ID"},
- tok=user2_tok,
- )
-
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Reflect the current state of the room
- self.assertEqual(
- response_body["rooms"][room_id1]["name"],
- "my super room",
- response_body["rooms"][room_id1],
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["avatar"],
- "mxc://DUMMY_MEDIA_ID",
- response_body["rooms"][room_id1],
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- 2,
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 0,
- )
- self.assertIsNone(
- response_body["rooms"][room_id1].get("is_dm"),
- )
-
- def test_rooms_meta_when_invited(self) -> None:
- """
- Test that the `rooms` `name` and `avatar` are included in the response and
- reflect the current state of the room when the user is invited to the room.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "name": "my super room",
- },
- )
- # Set the room avatar URL
- self.helper.send_state(
- room_id1,
- EventTypes.RoomAvatar,
- {"url": "mxc://DUMMY_MEDIA_ID"},
- tok=user2_tok,
- )
-
- # User1 is invited to the room
- self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
-
- # Update the room name after user1 has left
- self.helper.send_state(
- room_id1,
- EventTypes.Name,
- {"name": "my super duper room"},
- tok=user2_tok,
- )
- # Update the room avatar URL after user1 has left
- self.helper.send_state(
- room_id1,
- EventTypes.RoomAvatar,
- {"url": "mxc://UPDATED_DUMMY_MEDIA_ID"},
- tok=user2_tok,
- )
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # This should still reflect the current state of the room even when the user is
- # invited.
- self.assertEqual(
- response_body["rooms"][room_id1]["name"],
- "my super duper room",
- response_body["rooms"][room_id1],
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["avatar"],
- "mxc://UPDATED_DUMMY_MEDIA_ID",
- response_body["rooms"][room_id1],
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- 1,
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 1,
- )
- self.assertIsNone(
- response_body["rooms"][room_id1].get("is_dm"),
- )
-
- def test_rooms_meta_when_banned(self) -> None:
- """
- Test that the `rooms` `name` and `avatar` reflect the state of the room when the
- user was banned (do not leak current state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "name": "my super room",
- },
- )
- # Set the room avatar URL
- self.helper.send_state(
- room_id1,
- EventTypes.RoomAvatar,
- {"url": "mxc://DUMMY_MEDIA_ID"},
- tok=user2_tok,
- )
-
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
-
- # Update the room name after user1 has left
- self.helper.send_state(
- room_id1,
- EventTypes.Name,
- {"name": "my super duper room"},
- tok=user2_tok,
- )
- # Update the room avatar URL after user1 has left
- self.helper.send_state(
- room_id1,
- EventTypes.RoomAvatar,
- {"url": "mxc://UPDATED_DUMMY_MEDIA_ID"},
- tok=user2_tok,
- )
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Reflect the state of the room at the time of leaving
- self.assertEqual(
- response_body["rooms"][room_id1]["name"],
- "my super room",
- response_body["rooms"][room_id1],
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["avatar"],
- "mxc://DUMMY_MEDIA_ID",
- response_body["rooms"][room_id1],
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- # FIXME: The actual number should be "1" (user2) but we currently don't
- # support this for rooms where the user has left/been banned.
- 0,
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 0,
- )
- self.assertIsNone(
- response_body["rooms"][room_id1].get("is_dm"),
- )
-
- def test_rooms_meta_heroes(self) -> None:
- """
- Test that the `rooms` `heroes` are included in the response when the room
- doesn't have a room name set.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- _user3_tok = self.login(user3_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "name": "my super room",
- },
- )
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- # User3 is invited
- self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok)
-
- room_id2 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- # No room name set so that `heroes` is populated
- #
- # "name": "my super room2",
- },
- )
- self.helper.join(room_id2, user1_id, tok=user1_tok)
- # User3 is invited
- self.helper.invite(room_id2, src=user2_id, targ=user3_id, tok=user2_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Room1 has a name so we shouldn't see any `heroes` which the client would use
- # the calculate the room name themselves.
- self.assertEqual(
- response_body["rooms"][room_id1]["name"],
- "my super room",
- response_body["rooms"][room_id1],
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("heroes"))
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- 2,
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 1,
- )
-
- # Room2 doesn't have a name so we should see `heroes` populated
- self.assertIsNone(response_body["rooms"][room_id2].get("name"))
- self.assertCountEqual(
- [
- hero["user_id"]
- for hero in response_body["rooms"][room_id2].get("heroes", [])
- ],
- # Heroes shouldn't include the user themselves (we shouldn't see user1)
- [user2_id, user3_id],
- )
- self.assertEqual(
- response_body["rooms"][room_id2]["joined_count"],
- 2,
- )
- self.assertEqual(
- response_body["rooms"][room_id2]["invited_count"],
- 1,
- )
-
- # We didn't request any state so we shouldn't see any `required_state`
- self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
- self.assertIsNone(response_body["rooms"][room_id2].get("required_state"))
-
- def test_rooms_meta_heroes_max(self) -> None:
- """
- Test that the `rooms` `heroes` only includes the first 5 users (not including
- yourself).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
- user5_id = self.register_user("user5", "pass")
- user5_tok = self.login(user5_id, "pass")
- user6_id = self.register_user("user6", "pass")
- user6_tok = self.login(user6_id, "pass")
- user7_id = self.register_user("user7", "pass")
- user7_tok = self.login(user7_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- # No room name set so that `heroes` is populated
- #
- # "name": "my super room",
- },
- )
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user3_id, tok=user3_tok)
- self.helper.join(room_id1, user4_id, tok=user4_tok)
- self.helper.join(room_id1, user5_id, tok=user5_tok)
- self.helper.join(room_id1, user6_id, tok=user6_tok)
- self.helper.join(room_id1, user7_id, tok=user7_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Room2 doesn't have a name so we should see `heroes` populated
- self.assertIsNone(response_body["rooms"][room_id1].get("name"))
- self.assertCountEqual(
- [
- hero["user_id"]
- for hero in response_body["rooms"][room_id1].get("heroes", [])
- ],
- # Heroes should be the first 5 users in the room (excluding the user
- # themselves, we shouldn't see `user1`)
- [user2_id, user3_id, user4_id, user5_id, user6_id],
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- 7,
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 0,
- )
-
- # We didn't request any state so we shouldn't see any `required_state`
- self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
-
- def test_rooms_meta_heroes_when_banned(self) -> None:
- """
- Test that the `rooms` `heroes` are included in the response when the room
- doesn't have a room name set but doesn't leak information past their ban.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- _user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
- user5_id = self.register_user("user5", "pass")
- _user5_tok = self.login(user5_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- # No room name set so that `heroes` is populated
- #
- # "name": "my super room",
- },
- )
- # User1 joins the room
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- # User3 is invited
- self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok)
-
- # User1 is banned from the room
- self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
-
- # User4 joins the room after user1 is banned
- self.helper.join(room_id1, user4_id, tok=user4_tok)
- # User5 is invited after user1 is banned
- self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Room2 doesn't have a name so we should see `heroes` populated
- self.assertIsNone(response_body["rooms"][room_id1].get("name"))
- self.assertCountEqual(
- [
- hero["user_id"]
- for hero in response_body["rooms"][room_id1].get("heroes", [])
- ],
- # Heroes shouldn't include the user themselves (we shouldn't see user1). We
- # also shouldn't see user4 since they joined after user1 was banned.
- #
- # FIXME: The actual result should be `[user2_id, user3_id]` but we currently
- # don't support this for rooms where the user has left/been banned.
- [],
- )
-
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- # FIXME: The actual number should be "1" (user2) but we currently don't
- # support this for rooms where the user has left/been banned.
- 0,
- )
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- # We shouldn't see user5 since they were invited after user1 was banned.
- #
- # FIXME: The actual number should be "1" (user3) but we currently don't
- # support this for rooms where the user has left/been banned.
- 0,
- )
-
- def test_rooms_bump_stamp(self) -> None:
- """
- Test that `bump_stamp` is present and pointing to relevant events.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- )
- event_response1 = message_response = self.helper.send(
- room_id1, "message in room1", tok=user1_tok
- )
- event_pos1 = self.get_success(
- self.store.get_position_for_event(event_response1["event_id"])
- )
- room_id2 = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- )
- send_response2 = self.helper.send(room_id2, "message in room2", tok=user1_tok)
- event_pos2 = self.get_success(
- self.store.get_position_for_event(send_response2["event_id"])
- )
-
- # Send a reaction in room1 but it shouldn't affect the `bump_stamp`
- # because reactions are not part of the `DEFAULT_BUMP_EVENT_TYPES`
- self.helper.send_event(
- room_id1,
- type=EventTypes.Reaction,
- content={
- "m.relates_to": {
- "event_id": message_response["event_id"],
- "key": "👍",
- "rel_type": "m.annotation",
- }
- },
- tok=user1_tok,
- )
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 100,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["foo-list"],
- response_body["lists"].keys(),
- )
-
- # Make sure the list includes the rooms in the right order
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- # room1 sorts before room2 because it has the latest event (the
- # reaction)
- "room_ids": [room_id1, room_id2],
- }
- ],
- response_body["lists"]["foo-list"],
- )
-
- # The `bump_stamp` for room1 should point at the latest message (not the
- # reaction since it's not one of the `DEFAULT_BUMP_EVENT_TYPES`)
- self.assertEqual(
- response_body["rooms"][room_id1]["bump_stamp"],
- event_pos1.stream,
- response_body["rooms"][room_id1],
- )
-
- # The `bump_stamp` for room2 should point at the latest message
- self.assertEqual(
- response_body["rooms"][room_id2]["bump_stamp"],
- event_pos2.stream,
- response_body["rooms"][room_id2],
- )
-
- def test_rooms_bump_stamp_backfill(self) -> None:
- """
- Test that `bump_stamp` ignores backfilled events, i.e. events with a
- negative stream ordering.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote room
- creator = "@user:other"
- room_id = "!foo:other"
- shared_kwargs = {
- "room_id": room_id,
- "room_version": "10",
- }
-
- create_tuple = self.get_success(
- create_event(
- self.hs,
- prev_event_ids=[],
- type=EventTypes.Create,
- state_key="",
- sender=creator,
- **shared_kwargs,
- )
- )
- creator_tuple = self.get_success(
- create_event(
- self.hs,
- prev_event_ids=[create_tuple[0].event_id],
- auth_event_ids=[create_tuple[0].event_id],
- type=EventTypes.Member,
- state_key=creator,
- content={"membership": Membership.JOIN},
- sender=creator,
- **shared_kwargs,
- )
- )
- # We add a message event as a valid "bump type"
- msg_tuple = self.get_success(
- create_event(
- self.hs,
- prev_event_ids=[creator_tuple[0].event_id],
- auth_event_ids=[create_tuple[0].event_id],
- type=EventTypes.Message,
- content={"body": "foo", "msgtype": "m.text"},
- sender=creator,
- **shared_kwargs,
- )
- )
- invite_tuple = self.get_success(
- create_event(
- self.hs,
- prev_event_ids=[msg_tuple[0].event_id],
- auth_event_ids=[create_tuple[0].event_id, creator_tuple[0].event_id],
- type=EventTypes.Member,
- state_key=user1_id,
- content={"membership": Membership.INVITE},
- sender=creator,
- **shared_kwargs,
- )
- )
-
- remote_events_and_contexts = [
- create_tuple,
- creator_tuple,
- msg_tuple,
- invite_tuple,
- ]
-
- # Ensure the local HS knows the room version
- self.get_success(
- self.store.store_room(room_id, creator, False, RoomVersions.V10)
- )
-
- # Persist these events as backfilled events.
- persistence = self.hs.get_storage_controllers().persistence
- assert persistence is not None
-
- for event, context in remote_events_and_contexts:
- self.get_success(persistence.persist_event(event, context, backfilled=True))
-
- # Now we join the local user to the room
- join_tuple = self.get_success(
- create_event(
- self.hs,
- prev_event_ids=[invite_tuple[0].event_id],
- auth_event_ids=[create_tuple[0].event_id, invite_tuple[0].event_id],
- type=EventTypes.Member,
- state_key=user1_id,
- content={"membership": Membership.JOIN},
- sender=user1_id,
- **shared_kwargs,
- )
- )
- self.get_success(persistence.persist_event(*join_tuple))
-
- # Doing an SS request should return a positive `bump_stamp`, even though
- # the only event that matches the bump types has as negative stream
- # ordering.
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 5,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- self.assertGreater(response_body["rooms"][room_id]["bump_stamp"], 0)
diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py
deleted file mode 100644
index a13cad223f..0000000000
--- a/tests/rest/client/sliding_sync/test_rooms_required_state.py
+++ /dev/null
@@ -1,707 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-
-from parameterized import parameterized
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
-from synapse.handlers.sliding_sync import StateValues
-from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-from tests.test_utils.event_injection import mark_event_as_partial_state
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
- """
- Test `rooms.required_state` in the Sliding Sync API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
-
- def test_rooms_no_required_state(self) -> None:
- """
- Empty `rooms.required_state` should not return any state events in the room
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- # Empty `required_state`
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # No `required_state` in response
- self.assertIsNone(
- response_body["rooms"][room_id1].get("required_state"),
- response_body["rooms"][room_id1],
- )
-
- def test_rooms_required_state_initial_sync(self) -> None:
- """
- Test `rooms.required_state` returns requested state events in the room during an
- initial sync.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.RoomHistoryVisibility, ""],
- # This one doesn't exist in the room
- [EventTypes.Tombstone, ""],
- ],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- state_map[(EventTypes.RoomHistoryVisibility, "")],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_required_state_incremental_sync(self) -> None:
- """
- Test `rooms.required_state` returns requested state events in the room during an
- incremental sync.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.RoomHistoryVisibility, ""],
- # This one doesn't exist in the room
- [EventTypes.Tombstone, ""],
- ],
- "timeline_limit": 1,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Send a message so the room comes down sync.
- self.helper.send(room_id1, "msg", tok=user1_tok)
-
- # Make the incremental Sliding Sync request
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # We only return updates but only if we've sent the room down the
- # connection before.
- self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_incremental_sync_restart(self) -> None:
- """
- Test that after a restart (and so the in memory caches are reset) that
- we correctly return an `M_UNKNOWN_POS`
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.RoomHistoryVisibility, ""],
- # This one doesn't exist in the room
- [EventTypes.Tombstone, ""],
- ],
- "timeline_limit": 1,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Reset the in-memory cache
- self.hs.get_sliding_sync_handler().connection_store._connections.clear()
-
- # Make the Sliding Sync request
- channel = self.make_request(
- method="POST",
- path=self.sync_endpoint + f"?pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 400, channel.json_body)
- self.assertEqual(
- channel.json_body["errcode"], "M_UNKNOWN_POS", channel.json_body
- )
-
- def test_rooms_required_state_wildcard(self) -> None:
- """
- Test `rooms.required_state` returns all state events when using wildcard `["*", "*"]`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="namespaced",
- body={"foo": "bar"},
- tok=user2_tok,
- )
-
- # Make the Sliding Sync request with wildcards for the `event_type` and `state_key`
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [StateValues.WILDCARD, StateValues.WILDCARD],
- ],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- # We should see all the state events in the room
- state_map.values(),
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_required_state_wildcard_event_type(self) -> None:
- """
- Test `rooms.required_state` returns relevant state events when using wildcard in
- the event_type `["*", "foobarbaz"]`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key=user2_id,
- body={"foo": "bar"},
- tok=user2_tok,
- )
-
- # Make the Sliding Sync request with wildcards for the `event_type`
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [StateValues.WILDCARD, user2_id],
- ],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- # We expect at-least any state event with the `user2_id` as the `state_key`
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Member, user2_id)],
- state_map[("org.matrix.foo_state", user2_id)],
- },
- # Ideally, this would be exact but we're currently returning all state
- # events when the `event_type` is a wildcard.
- exact=False,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_required_state_wildcard_state_key(self) -> None:
- """
- Test `rooms.required_state` returns relevant state events when using wildcard in
- the state_key `["foobarbaz","*"]`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request with wildcards for the `state_key`
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Member, StateValues.WILDCARD],
- ],
- "timeline_limit": 0,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Member, user1_id)],
- state_map[(EventTypes.Member, user2_id)],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_required_state_lazy_loading_room_members(self) -> None:
- """
- Test `rooms.required_state` returns people relevant to the timeline when
- lazy-loading room members, `["m.room.member","$LAZY"]`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user3_id, tok=user3_tok)
-
- self.helper.send(room_id1, "1", tok=user2_tok)
- self.helper.send(room_id1, "2", tok=user3_tok)
- self.helper.send(room_id1, "3", tok=user2_tok)
-
- # Make the Sliding Sync request with lazy loading for the room members
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.Member, StateValues.LAZY],
- ],
- "timeline_limit": 3,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- # Only user2 and user3 sent events in the 3 events we see in the `timeline`
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- state_map[(EventTypes.Member, user2_id)],
- state_map[(EventTypes.Member, user3_id)],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_required_state_me(self) -> None:
- """
- Test `rooms.required_state` correctly handles $ME.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- self.helper.send(room_id1, "1", tok=user2_tok)
-
- # Also send normal state events with state keys of the users, first
- # change the power levels to allow this.
- self.helper.send_state(
- room_id1,
- event_type=EventTypes.PowerLevels,
- body={"users": {user1_id: 50, user2_id: 100}},
- tok=user2_tok,
- )
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo",
- state_key=user1_id,
- body={},
- tok=user1_tok,
- )
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo",
- state_key=user2_id,
- body={},
- tok=user2_tok,
- )
-
- # Make the Sliding Sync request with a request for '$ME'.
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.Member, StateValues.ME],
- ["org.matrix.foo", StateValues.ME],
- ],
- "timeline_limit": 3,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- # Only user2 and user3 sent events in the 3 events we see in the `timeline`
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- state_map[(EventTypes.Member, user1_id)],
- state_map[("org.matrix.foo", user1_id)],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- @parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
- def test_rooms_required_state_leave_ban(self, stop_membership: str) -> None:
- """
- Test `rooms.required_state` should not return state past a leave/ban event.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.Member, "*"],
- ["org.matrix.foo_state", ""],
- ],
- "timeline_limit": 3,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.join(room_id1, user3_id, tok=user3_tok)
-
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
-
- if stop_membership == Membership.LEAVE:
- # User 1 leaves
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
- elif stop_membership == Membership.BAN:
- # User 1 is banned
- self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- # Change the state after user 1 leaves
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="",
- body={"foo": "qux"},
- tok=user2_tok,
- )
- self.helper.leave(room_id1, user3_id, tok=user3_tok)
-
- # Make the Sliding Sync request with lazy loading for the room members
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # Only user2 and user3 sent events in the 3 events we see in the `timeline`
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- state_map[(EventTypes.Member, user1_id)],
- state_map[(EventTypes.Member, user2_id)],
- state_map[(EventTypes.Member, user3_id)],
- state_map[("org.matrix.foo_state", "")],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_required_state_combine_superset(self) -> None:
- """
- Test `rooms.required_state` is combined across lists and room subscriptions.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.foo_state",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- self.helper.send_state(
- room_id1,
- event_type="org.matrix.bar_state",
- state_key="",
- body={"bar": "qux"},
- tok=user2_tok,
- )
-
- # Make the Sliding Sync request with wildcards for the `state_key`
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- [EventTypes.Member, user1_id],
- ],
- "timeline_limit": 0,
- },
- "bar-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Member, StateValues.WILDCARD],
- ["org.matrix.foo_state", ""],
- ],
- "timeline_limit": 0,
- },
- },
- "room_subscriptions": {
- room_id1: {
- "required_state": [["org.matrix.bar_state", ""]],
- "timeline_limit": 0,
- }
- },
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- state_map = self.get_success(
- self.storage_controllers.state.get_current_state(room_id1)
- )
-
- self._assertRequiredStateIncludes(
- response_body["rooms"][room_id1]["required_state"],
- {
- state_map[(EventTypes.Create, "")],
- state_map[(EventTypes.Member, user1_id)],
- state_map[(EventTypes.Member, user2_id)],
- state_map[("org.matrix.foo_state", "")],
- state_map[("org.matrix.bar_state", "")],
- },
- exact=True,
- )
- self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
-
- def test_rooms_required_state_partial_state(self) -> None:
- """
- Test partially-stated room are excluded unless `rooms.required_state` is
- lazy-loading room members.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- _join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- # Mark room2 as partial state
- self.get_success(
- mark_event_as_partial_state(self.hs, join_response2["event_id"], room_id2)
- )
-
- # Make the Sliding Sync request (NOT lazy-loading room members)
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- ],
- "timeline_limit": 0,
- },
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure the list includes room1 but room2 is excluded because it's still
- # partially-stated
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- "room_ids": [room_id1],
- }
- ],
- response_body["lists"]["foo-list"],
- )
-
- # Make the Sliding Sync request (with lazy-loading room members)
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [
- [EventTypes.Create, ""],
- # Lazy-load room members
- [EventTypes.Member, StateValues.LAZY],
- ],
- "timeline_limit": 0,
- },
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # The list should include both rooms now because we're lazy-loading room members
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- "room_ids": [room_id2, room_id1],
- }
- ],
- response_body["lists"]["foo-list"],
- )
diff --git a/tests/rest/client/sliding_sync/test_rooms_timeline.py b/tests/rest/client/sliding_sync/test_rooms_timeline.py
deleted file mode 100644
index 2e9586ca73..0000000000
--- a/tests/rest/client/sliding_sync/test_rooms_timeline.py
+++ /dev/null
@@ -1,575 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-from typing import List, Optional
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.rest.client import login, room, sync
-from synapse.server import HomeServer
-from synapse.types import StreamToken, StrSequence
-from synapse.util import Clock
-
-from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
- """
- Test `rooms.timeline` in the Sliding Sync API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.storage_controllers = hs.get_storage_controllers()
-
- def _assertListEqual(
- self,
- actual_items: StrSequence,
- expected_items: StrSequence,
- message: Optional[str] = None,
- ) -> None:
- """
- Like `self.assertListEqual(...)` but with an actually understandable diff message.
- """
-
- if actual_items == expected_items:
- return
-
- expected_lines: List[str] = []
- for expected_item in expected_items:
- is_expected_in_actual = expected_item in actual_items
- expected_lines.append(
- "{} {}".format(" " if is_expected_in_actual else "?", expected_item)
- )
-
- actual_lines: List[str] = []
- for actual_item in actual_items:
- is_actual_in_expected = actual_item in expected_items
- actual_lines.append(
- "{} {}".format("+" if is_actual_in_expected else " ", actual_item)
- )
-
- newline = "\n"
- expected_string = f"Expected items to be in actual ('?' = missing expected items):\n [\n{newline.join(expected_lines)}\n ]"
- actual_string = f"Actual ('+' = found expected items):\n [\n{newline.join(actual_lines)}\n ]"
- first_message = "Items must"
- diff_message = f"{first_message}\n{expected_string}\n{actual_string}"
-
- self.fail(f"{diff_message}\n{message}")
-
- def _assertTimelineEqual(
- self,
- *,
- room_id: str,
- actual_event_ids: List[str],
- expected_event_ids: List[str],
- message: Optional[str] = None,
- ) -> None:
- """
- Like `self.assertListEqual(...)` for event IDs in a room but will give a nicer
- output with context for what each event_id is (type, stream_ordering, content,
- etc).
- """
- if actual_event_ids == expected_event_ids:
- return
-
- event_id_set = set(actual_event_ids + expected_event_ids)
- events = self.get_success(self.store.get_events(event_id_set))
-
- def event_id_to_string(event_id: str) -> str:
- event = events.get(event_id)
- if event:
- state_key = event.get_state_key()
- state_key_piece = f", {state_key}" if state_key is not None else ""
- return (
- f"({event.internal_metadata.stream_ordering: >2}, {event.internal_metadata.instance_name}) "
- + f"{event.event_id} ({event.type}{state_key_piece}) {event.content.get('membership', '')}{event.content.get('body', '')}"
- )
-
- return f"{event_id} <event not found in room_id={room_id}>"
-
- self._assertListEqual(
- actual_items=[
- event_id_to_string(event_id) for event_id in actual_event_ids
- ],
- expected_items=[
- event_id_to_string(event_id) for event_id in expected_event_ids
- ],
- message=message,
- )
-
- def test_rooms_limited_initial_sync(self) -> None:
- """
- Test that we mark `rooms` as `limited=True` when we saturate the `timeline_limit`
- on initial sync.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity1", tok=user2_tok)
- self.helper.send(room_id1, "activity2", tok=user2_tok)
- event_response3 = self.helper.send(room_id1, "activity3", tok=user2_tok)
- event_pos3 = self.get_success(
- self.store.get_position_for_event(event_response3["event_id"])
- )
- event_response4 = self.helper.send(room_id1, "activity4", tok=user2_tok)
- event_pos4 = self.get_success(
- self.store.get_position_for_event(event_response4["event_id"])
- )
- event_response5 = self.helper.send(room_id1, "activity5", tok=user2_tok)
- user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 3,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # We expect to saturate the `timeline_limit` (there are more than 3 messages in the room)
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- True,
- response_body["rooms"][room_id1],
- )
- # Check to make sure the latest events are returned
- self._assertTimelineEqual(
- room_id=room_id1,
- actual_event_ids=[
- event["event_id"]
- for event in response_body["rooms"][room_id1]["timeline"]
- ],
- expected_event_ids=[
- event_response4["event_id"],
- event_response5["event_id"],
- user1_join_response["event_id"],
- ],
- message=str(response_body["rooms"][room_id1]["timeline"]),
- )
-
- # Check to make sure the `prev_batch` points at the right place
- prev_batch_token = self.get_success(
- StreamToken.from_string(
- self.store, response_body["rooms"][room_id1]["prev_batch"]
- )
- )
- prev_batch_room_stream_token_serialized = self.get_success(
- prev_batch_token.room_key.to_string(self.store)
- )
- # If we use the `prev_batch` token to look backwards, we should see `event3`
- # next so make sure the token encompasses it
- self.assertEqual(
- event_pos3.persisted_after(prev_batch_token.room_key),
- False,
- f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be >= event_pos3={self.get_success(event_pos3.to_room_stream_token().to_string(self.store))}",
- )
- # If we use the `prev_batch` token to look backwards, we shouldn't see `event4`
- # anymore since it was just returned in this response.
- self.assertEqual(
- event_pos4.persisted_after(prev_batch_token.room_key),
- True,
- f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be < event_pos4={self.get_success(event_pos4.to_room_stream_token().to_string(self.store))}",
- )
-
- # With no `from_token` (initial sync), it's all historical since there is no
- # "live" range
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 0,
- response_body["rooms"][room_id1],
- )
-
- def test_rooms_not_limited_initial_sync(self) -> None:
- """
- Test that we mark `rooms` as `limited=False` when there are no more events to
- paginate to.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity1", tok=user2_tok)
- self.helper.send(room_id1, "activity2", tok=user2_tok)
- self.helper.send(room_id1, "activity3", tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Make the Sliding Sync request
- timeline_limit = 100
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": timeline_limit,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # The timeline should be `limited=False` because we have all of the events (no
- # more to paginate to)
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- False,
- response_body["rooms"][room_id1],
- )
- expected_number_of_events = 9
- # We're just looking to make sure we got all of the events before hitting the `timeline_limit`
- self.assertEqual(
- len(response_body["rooms"][room_id1]["timeline"]),
- expected_number_of_events,
- response_body["rooms"][room_id1]["timeline"],
- )
- self.assertLessEqual(expected_number_of_events, timeline_limit)
-
- # With no `from_token` (initial sync), it's all historical since there is no
- # "live" token range.
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 0,
- response_body["rooms"][room_id1],
- )
-
- def test_rooms_incremental_sync(self) -> None:
- """
- Test `rooms` data during an incremental sync after an initial sync.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- self.helper.send(room_id1, "activity before initial sync1", tok=user2_tok)
-
- # Make an initial Sliding Sync request to grab a token. This is also a sanity
- # check that we can go from initial to incremental sync.
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 3,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Send some events but don't send enough to saturate the `timeline_limit`.
- # We want to later test that we only get the new events since the `next_pos`
- event_response2 = self.helper.send(room_id1, "activity after2", tok=user2_tok)
- event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok)
-
- # Make an incremental Sliding Sync request (what we're trying to test)
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # We only expect to see the new events since the last sync which isn't enough to
- # fill up the `timeline_limit`.
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- False,
- f'Our `timeline_limit` was {sync_body["lists"]["foo-list"]["timeline_limit"]} '
- + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
- + str(response_body["rooms"][room_id1]),
- )
- # Check to make sure the latest events are returned
- self._assertTimelineEqual(
- room_id=room_id1,
- actual_event_ids=[
- event["event_id"]
- for event in response_body["rooms"][room_id1]["timeline"]
- ],
- expected_event_ids=[
- event_response2["event_id"],
- event_response3["event_id"],
- ],
- message=str(response_body["rooms"][room_id1]["timeline"]),
- )
-
- # All events are "live"
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 2,
- response_body["rooms"][room_id1],
- )
-
- def test_rooms_newly_joined_incremental_sync(self) -> None:
- """
- Test that when we make an incremental sync with a `newly_joined` `rooms`, we are
- able to see some historical events before the `from_token`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity before token1", tok=user2_tok)
- event_response2 = self.helper.send(
- room_id1, "activity before token2", tok=user2_tok
- )
-
- # The `timeline_limit` is set to 4 so we can at least see one historical event
- # before the `from_token`. We should see historical events because this is a
- # `newly_joined` room.
- timeline_limit = 4
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": timeline_limit,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Join the room after the `from_token` which will make us consider this room as
- # `newly_joined`.
- user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- # Send some events but don't send enough to saturate the `timeline_limit`.
- # We want to later test that we only get the new events since the `next_pos`
- event_response3 = self.helper.send(
- room_id1, "activity after token3", tok=user2_tok
- )
- event_response4 = self.helper.send(
- room_id1, "activity after token4", tok=user2_tok
- )
-
- # Make an incremental Sliding Sync request (what we're trying to test)
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # We should see the new events and the rest should be filled with historical
- # events which will make us `limited=True` since there are more to paginate to.
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- True,
- f"Our `timeline_limit` was {timeline_limit} "
- + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
- + str(response_body["rooms"][room_id1]),
- )
- # Check to make sure that the "live" and historical events are returned
- self._assertTimelineEqual(
- room_id=room_id1,
- actual_event_ids=[
- event["event_id"]
- for event in response_body["rooms"][room_id1]["timeline"]
- ],
- expected_event_ids=[
- event_response2["event_id"],
- user1_join_response["event_id"],
- event_response3["event_id"],
- event_response4["event_id"],
- ],
- message=str(response_body["rooms"][room_id1]["timeline"]),
- )
-
- # Only events after the `from_token` are "live" (join, event3, event4)
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 3,
- response_body["rooms"][room_id1],
- )
-
- def test_rooms_ban_initial_sync(self) -> None:
- """
- Test that `rooms` we are banned from in an intial sync only allows us to see
- timeline events up to the ban event.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity before1", tok=user2_tok)
- self.helper.send(room_id1, "activity before2", tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok)
- event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok)
- user1_ban_response = self.helper.ban(
- room_id1, src=user2_id, targ=user1_id, tok=user2_tok
- )
-
- self.helper.send(room_id1, "activity after5", tok=user2_tok)
- self.helper.send(room_id1, "activity after6", tok=user2_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 3,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # We should see events before the ban but not after
- self._assertTimelineEqual(
- room_id=room_id1,
- actual_event_ids=[
- event["event_id"]
- for event in response_body["rooms"][room_id1]["timeline"]
- ],
- expected_event_ids=[
- event_response3["event_id"],
- event_response4["event_id"],
- user1_ban_response["event_id"],
- ],
- message=str(response_body["rooms"][room_id1]["timeline"]),
- )
- # No "live" events in an initial sync (no `from_token` to define the "live"
- # range)
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 0,
- response_body["rooms"][room_id1],
- )
- # There are more events to paginate to
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- True,
- response_body["rooms"][room_id1],
- )
-
- def test_rooms_ban_incremental_sync1(self) -> None:
- """
- Test that `rooms` we are banned from during the next incremental sync only
- allows us to see timeline events up to the ban event.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity before1", tok=user2_tok)
- self.helper.send(room_id1, "activity before2", tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 4,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- event_response3 = self.helper.send(room_id1, "activity after3", tok=user2_tok)
- event_response4 = self.helper.send(room_id1, "activity after4", tok=user2_tok)
- # The ban is within the token range (between the `from_token` and the sliding
- # sync request)
- user1_ban_response = self.helper.ban(
- room_id1, src=user2_id, targ=user1_id, tok=user2_tok
- )
-
- self.helper.send(room_id1, "activity after5", tok=user2_tok)
- self.helper.send(room_id1, "activity after6", tok=user2_tok)
-
- # Make the incremental Sliding Sync request
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # We should see events before the ban but not after
- self._assertTimelineEqual(
- room_id=room_id1,
- actual_event_ids=[
- event["event_id"]
- for event in response_body["rooms"][room_id1]["timeline"]
- ],
- expected_event_ids=[
- event_response3["event_id"],
- event_response4["event_id"],
- user1_ban_response["event_id"],
- ],
- message=str(response_body["rooms"][room_id1]["timeline"]),
- )
- # All live events in the incremental sync
- self.assertEqual(
- response_body["rooms"][room_id1]["num_live"],
- 3,
- response_body["rooms"][room_id1],
- )
- # There aren't anymore events to paginate to in this range
- self.assertEqual(
- response_body["rooms"][room_id1]["limited"],
- False,
- response_body["rooms"][room_id1],
- )
-
- def test_rooms_ban_incremental_sync2(self) -> None:
- """
- Test that `rooms` we are banned from before the incremental sync don't return
- any events in the timeline.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity before1", tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
-
- self.helper.send(room_id1, "activity after2", tok=user2_tok)
- # The ban is before we get our `from_token`
- self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
-
- self.helper.send(room_id1, "activity after3", tok=user2_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 4,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- self.helper.send(room_id1, "activity after4", tok=user2_tok)
-
- # Make the incremental Sliding Sync request
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # Nothing to see for this banned user in the room in the token range
- self.assertIsNone(response_body["rooms"].get(room_id1))
diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py
deleted file mode 100644
index cb7638c5ba..0000000000
--- a/tests/rest/client/sliding_sync/test_sliding_sync.py
+++ /dev/null
@@ -1,974 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-import logging
-from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
-
-from typing_extensions import assert_never
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.api.constants import (
- AccountDataTypes,
- EventContentFields,
- EventTypes,
- RoomTypes,
-)
-from synapse.events import EventBase
-from synapse.rest.client import devices, login, receipts, room, sync
-from synapse.server import HomeServer
-from synapse.types import (
- JsonDict,
- RoomStreamToken,
- SlidingSyncStreamToken,
- StreamKeyType,
- StreamToken,
-)
-from synapse.util import Clock
-from synapse.util.stringutils import random_string
-
-from tests import unittest
-from tests.server import TimedOutException
-
-logger = logging.getLogger(__name__)
-
-
-class SlidingSyncBase(unittest.HomeserverTestCase):
- """Base class for sliding sync test cases"""
-
- sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
- return config
-
- def do_sync(
- self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str
- ) -> Tuple[JsonDict, str]:
- """Do a sliding sync request with given body.
-
- Asserts the request was successful.
-
- Attributes:
- sync_body: The full request body to use
- since: Optional since token
- tok: Access token to use
-
- Returns:
- A tuple of the response body and the `pos` field.
- """
-
- sync_path = self.sync_endpoint
- if since:
- sync_path += f"?pos={since}"
-
- channel = self.make_request(
- method="POST",
- path=sync_path,
- content=sync_body,
- access_token=tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- return channel.json_body, channel.json_body["pos"]
-
- def _assertRequiredStateIncludes(
- self,
- actual_required_state: Any,
- expected_state_events: Iterable[EventBase],
- exact: bool = False,
- ) -> None:
- """
- Wrapper around `assertIncludes` to give slightly better looking diff error
- messages that include some context "$event_id (type, state_key)".
-
- Args:
- actual_required_state: The "required_state" of a room from a Sliding Sync
- request response.
- expected_state_events: The expected state events to be included in the
- `actual_required_state`.
- exact: Whether the actual state should be exactly equal to the expected
- state (no extras).
- """
-
- assert isinstance(actual_required_state, list)
- for event in actual_required_state:
- assert isinstance(event, dict)
-
- self.assertIncludes(
- {
- f'{event["event_id"]} ("{event["type"]}", "{event["state_key"]}")'
- for event in actual_required_state
- },
- {
- f'{event.event_id} ("{event.type}", "{event.state_key}")'
- for event in expected_state_events
- },
- exact=exact,
- # Message to help understand the diff in context
- message=str(actual_required_state),
- )
-
- def _bump_notifier_wait_for_events(
- self,
- user_id: str,
- wake_stream_key: Literal[
- StreamKeyType.ACCOUNT_DATA,
- StreamKeyType.PRESENCE,
- ],
- ) -> None:
- """
- Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
- Sync results.
-
- Args:
- user_id: The user ID to wake up the notifier for
- wake_stream_key: The stream key to wake up. This will create an actual new
- entity in that stream so it's best to choose one that won't affect the
- Sliding Sync results you're testing for. In other words, if your testing
- account data, choose `StreamKeyType.PRESENCE` instead. We support two
- possible stream keys because you're probably testing one or the other so
- one is always a "safe" option.
- """
- # We're expecting some new activity from this point onwards
- from_token = self.hs.get_event_sources().get_current_token()
-
- triggered_notifier_wait_for_events = False
-
- async def _on_new_acivity(
- before_token: StreamToken, after_token: StreamToken
- ) -> bool:
- nonlocal triggered_notifier_wait_for_events
- triggered_notifier_wait_for_events = True
- return True
-
- notifier = self.hs.get_notifier()
-
- # Listen for some new activity for the user. We're just trying to confirm that
- # our bump below actually does what we think it does (triggers new activity for
- # the user).
- result_awaitable = notifier.wait_for_events(
- user_id,
- 1000,
- _on_new_acivity,
- from_token=from_token,
- )
-
- # Update the account data or presence so that `notifier.wait_for_events(...)`
- # wakes up. We chose these two options because they're least likely to show up
- # in the Sliding Sync response so it won't affect whether we have results.
- if wake_stream_key == StreamKeyType.ACCOUNT_DATA:
- self.get_success(
- self.hs.get_account_data_handler().add_account_data_for_user(
- user_id,
- "org.matrix.foobarbaz",
- {"foo": "bar"},
- )
- )
- elif wake_stream_key == StreamKeyType.PRESENCE:
- sending_user_id = self.register_user(
- "user_bump_notifier_wait_for_events_" + random_string(10), "pass"
- )
- sending_user_tok = self.login(sending_user_id, "pass")
- test_msg = {"foo": "bar"}
- chan = self.make_request(
- "PUT",
- "/_matrix/client/r0/sendToDevice/m.test/1234",
- content={"messages": {user_id: {"d1": test_msg}}},
- access_token=sending_user_tok,
- )
- self.assertEqual(chan.code, 200, chan.result)
- else:
- assert_never(wake_stream_key)
-
- # Wait for our notifier result
- self.get_success(result_awaitable)
-
- if not triggered_notifier_wait_for_events:
- raise AssertionError(
- "Expected `notifier.wait_for_events(...)` to be triggered"
- )
-
-
-class SlidingSyncTestCase(SlidingSyncBase):
- """
- Tests regarding MSC3575 Sliding Sync `/sync` endpoint.
-
- Please put tests in more specific test files if applicable. This test class is meant
- for generic behavior of the endpoint.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- devices.register_servlets,
- receipts.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
- self.storage_controllers = hs.get_storage_controllers()
- self.account_data_handler = hs.get_account_data_handler()
-
- def _add_new_dm_to_global_account_data(
- self, source_user_id: str, target_user_id: str, target_room_id: str
- ) -> None:
- """
- Helper to handle inserting a new DM for the source user into global account data
- (handles all of the list merging).
-
- Args:
- source_user_id: The user ID of the DM mapping we're going to update
- target_user_id: User ID of the person the DM is with
- target_room_id: Room ID of the DM
- """
-
- # Get the current DM map
- existing_dm_map = self.get_success(
- self.store.get_global_account_data_by_type_for_user(
- source_user_id, AccountDataTypes.DIRECT
- )
- )
- # Scrutinize the account data since it has no concrete type. We're just copying
- # everything into a known type. It should be a mapping from user ID to a list of
- # room IDs. Ignore anything else.
- new_dm_map: Dict[str, List[str]] = {}
- if isinstance(existing_dm_map, dict):
- for user_id, room_ids in existing_dm_map.items():
- if isinstance(user_id, str) and isinstance(room_ids, list):
- for room_id in room_ids:
- if isinstance(room_id, str):
- new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
- room_id
- ]
-
- # Add the new DM to the map
- new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
- target_room_id
- ]
- # Save the DM map to global account data
- self.get_success(
- self.store.add_account_data_for_user(
- source_user_id,
- AccountDataTypes.DIRECT,
- new_dm_map,
- )
- )
-
- def _create_dm_room(
- self,
- inviter_user_id: str,
- inviter_tok: str,
- invitee_user_id: str,
- invitee_tok: str,
- should_join_room: bool = True,
- ) -> str:
- """
- Helper to create a DM room as the "inviter" and invite the "invitee" user to the
- room. The "invitee" user also will join the room. The `m.direct` account data
- will be set for both users.
- """
-
- # Create a room and send an invite the other user
- room_id = self.helper.create_room_as(
- inviter_user_id,
- is_public=False,
- tok=inviter_tok,
- )
- self.helper.invite(
- room_id,
- src=inviter_user_id,
- targ=invitee_user_id,
- tok=inviter_tok,
- extra_data={"is_direct": True},
- )
- if should_join_room:
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
-
- # Mimic the client setting the room as a direct message in the global account
- # data for both users.
- self._add_new_dm_to_global_account_data(
- invitee_user_id, inviter_user_id, room_id
- )
- self._add_new_dm_to_global_account_data(
- inviter_user_id, invitee_user_id, room_id
- )
-
- return room_id
-
- def test_sync_list(self) -> None:
- """
- Test that room IDs show up in the Sliding Sync `lists`
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["foo-list"],
- response_body["lists"].keys(),
- )
-
- # Make sure the list includes the room we are joined to
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [room_id],
- }
- ],
- response_body["lists"]["foo-list"],
- )
-
- def test_wait_for_sync_token(self) -> None:
- """
- Test that worker will wait until it catches up to the given token
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a future token that will cause us to wait. Since we never send a new
- # event to reach that future stream_ordering, the worker will wait until the
- # full timeout.
- stream_id_gen = self.store.get_events_stream_id_generator()
- stream_id = self.get_success(stream_id_gen.get_next().__aenter__())
- current_token = self.event_sources.get_current_token()
- future_position_token = current_token.copy_and_replace(
- StreamKeyType.ROOM,
- RoomStreamToken(stream=stream_id),
- )
-
- future_position_token_serialized = self.get_success(
- SlidingSyncStreamToken(future_position_token, 0).to_string(self.store)
- )
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- }
- }
- }
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?pos={future_position_token_serialized}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 10 seconds to make `notifier.wait_for_stream_token(from_token)`
- # timeout
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=9900)
- channel.await_result(timeout_ms=200)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # We expect the next `pos` in the result to be the same as what we requested
- # with because we weren't able to find anything new yet.
- self.assertEqual(channel.json_body["pos"], future_position_token_serialized)
-
- def test_wait_for_new_data(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive.
-
- (Only applies to incremental syncs with a `timeout` specified)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 0]],
- "required_state": [],
- "timeline_limit": 1,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Bump the room with new events to trigger new results
- event_response1 = self.helper.send(
- room_id, "new activity in room", tok=user1_tok
- )
- # Should respond before the 10 second timeout
- channel.await_result(timeout_ms=3000)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check to make sure the new event is returned
- self.assertEqual(
- [
- event["event_id"]
- for event in channel.json_body["rooms"][room_id]["timeline"]
- ],
- [
- event_response1["event_id"],
- ],
- channel.json_body["rooms"][room_id]["timeline"],
- )
-
- def test_wait_for_new_data_timeout(self) -> None:
- """
- Test to make sure that the Sliding Sync request waits for new data to arrive but
- no data ever arrives so we timeout. We're also making sure that the default data
- doesn't trigger a false-positive for new data.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 0]],
- "required_state": [],
- "timeline_limit": 1,
- }
- }
- }
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?timeout=10000&pos={from_token}",
- content=sync_body,
- access_token=user1_tok,
- await_result=False,
- )
- # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=5000)
- # Wake-up `notifier.wait_for_events(...)` that will cause us test
- # `SlidingSyncResult.__bool__` for new results.
- self._bump_notifier_wait_for_events(
- user1_id, wake_stream_key=StreamKeyType.ACCOUNT_DATA
- )
- # Block for a little bit more to ensure we don't see any new results.
- with self.assertRaises(TimedOutException):
- channel.await_result(timeout_ms=4000)
- # Wait for the sync to complete (wait for the rest of the 10 second timeout,
- # 5000 + 4000 + 1200 > 10000)
- channel.await_result(timeout_ms=1200)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # There should be no room sent down.
- self.assertFalse(channel.json_body["rooms"])
-
- def test_filter_list(self) -> None:
- """
- Test that filters apply to `lists`
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a DM room
- joined_dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- should_join_room=True,
- )
- invited_dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- should_join_room=False,
- )
-
- # Create a normal room
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- # Create a room that user1 is invited to
- invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- # Absense of filters does not imply "False" values
- "all": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {},
- },
- # Test single truthy filter
- "dms": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": True},
- },
- # Test single falsy filter
- "non-dms": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": False},
- },
- # Test how multiple filters should stack (AND'd together)
- "room-invites": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": False, "is_invite": True},
- },
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["all", "dms", "non-dms", "room-invites"],
- response_body["lists"].keys(),
- )
-
- # Make sure the lists have the correct rooms
- self.assertListEqual(
- list(response_body["lists"]["all"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [
- invite_room_id,
- room_id,
- invited_dm_room_id,
- joined_dm_room_id,
- ],
- }
- ],
- list(response_body["lists"]["all"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["dms"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invited_dm_room_id, joined_dm_room_id],
- }
- ],
- list(response_body["lists"]["dms"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["non-dms"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invite_room_id, room_id],
- }
- ],
- list(response_body["lists"]["non-dms"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["room-invites"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invite_room_id],
- }
- ],
- list(response_body["lists"]["room-invites"]),
- )
-
- # Ensure DM's are correctly marked
- self.assertDictEqual(
- {
- room_id: room.get("is_dm")
- for room_id, room in response_body["rooms"].items()
- },
- {
- invite_room_id: None,
- room_id: None,
- invited_dm_room_id: True,
- joined_dm_room_id: True,
- },
- )
-
- def test_filter_regardless_of_membership_server_left_room(self) -> None:
- """
- Test that filters apply to rooms regardless of membership. We're also
- compounding the problem by having all of the local users leave the room causing
- our server to leave the room.
-
- We want to make sure that if someone is filtering rooms, and leaves, you still
- get that final update down sync that you left.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- # Create an encrypted space room
- space_room_id = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- self.helper.send_state(
- space_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user2_tok,
- )
- self.helper.join(space_room_id, user1_id, tok=user1_tok)
-
- # Make an initial Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint,
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
- },
- }
- },
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- from_token = channel.json_body["pos"]
-
- # Make sure the response has the lists we requested
- self.assertListEqual(
- list(channel.json_body["lists"].keys()),
- ["all-list", "foo-list"],
- channel.json_body["lists"].keys(),
- )
-
- # Make sure the lists have the correct rooms
- self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id, room_id],
- }
- ],
- )
- self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id],
- }
- ],
- )
-
- # Everyone leaves the encrypted space room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
- self.helper.leave(space_room_id, user2_id, tok=user2_tok)
-
- # Make an incremental Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?pos={from_token}",
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
- },
- }
- },
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Make sure the lists have the correct rooms even though we `newly_left`
- self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id, room_id],
- }
- ],
- )
- self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id],
- }
- ],
- )
-
- def test_sort_list(self) -> None:
- """
- Test that the `lists` are sorted by `stream_ordering`
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- room_id3 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
-
- # Activity that will order the rooms
- self.helper.send(room_id3, "activity in room3", tok=user1_tok)
- self.helper.send(room_id1, "activity in room1", tok=user1_tok)
- self.helper.send(room_id2, "activity in room2", tok=user1_tok)
-
- # Make the Sliding Sync request
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["foo-list"],
- response_body["lists"].keys(),
- )
-
- # Make sure the list is sorted in the way we expect
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [room_id2, room_id1, room_id3],
- }
- ],
- response_body["lists"]["foo-list"],
- )
-
- def test_sliced_windows(self) -> None:
- """
- Test that the `lists` `ranges` are sliced correctly. Both sides of each range
- are inclusive.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- _room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- room_id3 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
-
- # Make the Sliding Sync request for a single room
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 0]],
- "required_state": [],
- "timeline_limit": 1,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["foo-list"],
- response_body["lists"].keys(),
- )
- # Make sure the list is sorted in the way we expect
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 0],
- "room_ids": [room_id3],
- }
- ],
- response_body["lists"]["foo-list"],
- )
-
- # Make the Sliding Sync request for the first two rooms
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 1,
- }
- }
- }
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["foo-list"],
- response_body["lists"].keys(),
- )
- # Make sure the list is sorted in the way we expect
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- "room_ids": [room_id3, room_id2],
- }
- ],
- response_body["lists"]["foo-list"],
- )
-
- def test_rooms_with_no_updates_do_not_come_down_incremental_sync(self) -> None:
- """
- Test that rooms with no updates are returned in subsequent incremental
- syncs.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
-
- _, from_token = self.do_sync(sync_body, tok=user1_tok)
-
- # Make the incremental Sliding Sync request
- response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
-
- # Nothing has happened in the room, so the room should not come down
- # /sync.
- self.assertIsNone(response_body["rooms"].get(room_id1))
-
- def test_empty_initial_room_comes_down_sync(self) -> None:
- """
- Test that rooms come down /sync even with empty required state and
- timeline limit in initial sync.
- """
-
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- sync_body = {
- "lists": {
- "foo-list": {
- "ranges": [[0, 1]],
- "required_state": [],
- "timeline_limit": 0,
- }
- }
- }
-
- # Make the Sliding Sync request
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a85ea994de..f1e4bdea76 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -427,23 +426,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
text = None
for part in mail.walk():
if part.get_content_type() == "text/plain":
- text = part.get_payload(decode=True)
- if text is not None:
- # According to the logic table in `get_payload`, we know that
- # the result of `get_payload` will be `bytes`, but mypy doesn't
- # know this and complains. Thus, we assert the type.
- assert isinstance(text, bytes)
- text = text.decode("UTF-8")
-
+ text = part.get_payload(decode=True).decode("UTF-8")
break
if not text:
self.fail("Could not find text portion of email to parse")
- # `text` must be a `str`, after being decoded and determined just above
- # to not be `None` or an empty `str`.
- assert isinstance(text, str)
-
+ assert text is not None
match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email"
@@ -1219,23 +1208,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
text = None
for part in mail.walk():
if part.get_content_type() == "text/plain":
- text = part.get_payload(decode=True)
- if text is not None:
- # According to the logic table in `get_payload`, we know that
- # the result of `get_payload` will be `bytes`, but mypy doesn't
- # know this and complains. Thus, we assert the type.
- assert isinstance(text, bytes)
- text = text.decode("UTF-8")
-
+ text = part.get_payload(decode=True).decode("UTF-8")
break
if not text:
self.fail("Could not find text portion of email to parse")
- # `text` must be a `str`, after being decoded and determined just above
- # to not be `None` or an empty `str`.
- assert isinstance(text, str)
-
+ assert text is not None
match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email"
diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
index be6d7af2fc..ce505eef62 100644
--- a/tests/rest/client/test_account_data.py
+++ b/tests/rest/client/test_account_data.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 0b5daf4bb4..8bc36209e5 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020-2021 The Matrix.org Foundation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index a3ed12a38f..f1cfe0df91 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -24,8 +23,8 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
-from synapse.rest import admin, devices, sync
-from synapse.rest.client import keys, login, register
+from synapse.rest import admin, devices, room, sync
+from synapse.rest.client import account, keys, login, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -33,6 +32,146 @@ from synapse.util import Clock
from tests import unittest
+class DeviceListsTestCase(unittest.HomeserverTestCase):
+ """Tests regarding device list changes."""
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ account.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def test_receiving_local_device_list_changes(self) -> None:
+ """Tests that a local users that share a room receive each other's device list
+ changes.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # Create a room for them to coexist peacefully in
+ new_room_id = self.helper.create_room_as(
+ alice_user_id, is_public=True, tok=alice_access_token
+ )
+ self.assertIsNotNone(new_room_id)
+
+ # Have Bob join the room
+ self.helper.invite(
+ new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+ )
+ self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+ # Now have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ "/sync",
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch_token}&timeout=30000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+ # Check that bob's incremental sync contains the updated device list.
+ # If not, the client would only receive the device list update on the
+ # *next* sync.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+ def test_not_receiving_local_device_list_changes(self) -> None:
+ """Tests a local users DO NOT receive device updates from each other if they do not
+ share a room.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # These users do not share a room. They are lonely.
+
+ # Have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ "/sync",
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch_token}&timeout=1000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+ # Check that bob's incremental sync does not contain the updated device list.
+ bob_sync_channel.await_result()
+ self.assertEqual(
+ bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
+ )
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertNotIn(
+ alice_user_id, changed_device_lists, bob_sync_channel.json_body
+ )
+
+
class DevicesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 06f1c1b234..c456b24839 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 9cfc6b224f..b3b108dd5e 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -72,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
- self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
+ self.hs.is_mine = lambda target_user: False # type: ignore[method-assign]
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 8bbd109092..e99160c5ac 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -155,6 +154,71 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
}
def test_device_signing_with_uia(self) -> None:
+ """Device signing key upload requires UIA."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # add UI auth
+ content["auth"] = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": alice_id},
+ "password": password,
+ "session": session,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config({"ui_auth": {"session_timeout": "15m"}})
+ def test_device_signing_with_uia_session_timeout(self) -> None:
+ """Device signing key upload requires UIA buy passes with grace period."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config(
+ {
+ "experimental_features": {"msc3967_enabled": True},
+ "ui_auth": {"session_timeout": "15s"},
+ }
+ )
+ def test_device_signing_with_msc3967(self) -> None:
+ """Device signing key follows MSC3967 behaviour when enabled."""
password = "wonderland"
device_id = "ABCDEFGHI"
alice_id = self.register_user("alice", password)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 2b1e44381b..42610308d5 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -20,17 +19,7 @@
#
import time
import urllib.parse
-from typing import (
- Any,
- BinaryIO,
- Callable,
- Collection,
- Dict,
- List,
- Optional,
- Tuple,
- Union,
-)
+from typing import Any, Collection, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -44,9 +33,8 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.http.client import RawHeaders
from synapse.module_api import ModuleApi
-from synapse.rest.client import account, devices, login, logout, profile, register
+from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
@@ -59,7 +47,6 @@ from tests.handlers.test_saml import has_saml2
from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
-from tests.test_utils.oidc import FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
@@ -189,6 +176,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# rc_login dict here, we need to set this manually as well
"account": {"per_second": 10000, "burst_count": 10000},
},
+ "experimental_features": {"msc4041_enabled": True},
}
)
def test_POST_ratelimiting_per_address(self) -> None:
@@ -240,6 +228,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# rc_login dict here, we need to set this manually as well
"address": {"per_second": 10000, "burst_count": 10000},
},
+ "experimental_features": {"msc4041_enabled": True},
}
)
def test_POST_ratelimiting_per_account(self) -> None:
@@ -288,6 +277,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"address": {"per_second": 10000, "burst_count": 10000},
"failed_attempts": {"per_second": 0.17, "burst_count": 5},
},
+ "experimental_features": {"msc4041_enabled": True},
}
)
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
@@ -969,8 +959,9 @@ class CASTestCase(unittest.HomeserverTestCase):
# Test that the response is HTML.
self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = ""
- for header in channel.headers.getRawHeaders("Content-Type", []):
- content_type_header_value = header
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type_header_value = header[1].decode("utf8")
self.assertTrue(content_type_header_value.startswith("text/html"))
@@ -1432,19 +1423,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
class UsernamePickerTestCase(HomeserverTestCase):
"""Tests for the username picker flow of SSO login"""
- servlets = [
- login.register_servlets,
- profile.register_servlets,
- account.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.http_client = Mock(spec=["get_file"])
- self.http_client.get_file.side_effect = mock_get_file
- hs = self.setup_test_homeserver(
- proxied_blocklisted_http_client=self.http_client
- )
- return hs
+ servlets = [login.register_servlets]
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
@@ -1453,11 +1432,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
config["oidc_config"] = {}
config["oidc_config"].update(TEST_OIDC_CONFIG)
config["oidc_config"]["user_mapping_provider"] = {
- "config": {
- "display_name_template": "{{ user.displayname }}",
- "email_template": "{{ user.email }}",
- "picture_template": "{{ user.picture }}",
- }
+ "config": {"display_name_template": "{{ user.displayname }}"}
}
# whitelist this client URI so we redirect straight to it rather than
@@ -1470,22 +1445,15 @@ class UsernamePickerTestCase(HomeserverTestCase):
d.update(build_synapse_client_resource_tree(self.hs))
return d
- def proceed_to_username_picker_page(
- self,
- fake_oidc_server: FakeOidcServer,
- displayname: str,
- email: str,
- picture: str,
- ) -> Tuple[str, str]:
+ def test_username_picker(self) -> None:
+ """Test the happy path of a username picker flow."""
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+
# do the start of the login flow
channel, _ = self.helper.auth_via_oidc(
fake_oidc_server,
- {
- "sub": "tester",
- "displayname": displayname,
- "picture": picture,
- "email": email,
- },
+ {"sub": "tester", "displayname": "Jonny"},
TEST_CLIENT_REDIRECT_URL,
)
@@ -1512,132 +1480,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
- self.assertEqual(session.display_name, displayname)
- self.assertEqual(session.emails, [email])
- self.assertEqual(session.avatar_url, picture)
+ self.assertEqual(session.display_name, "Jonny")
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
# the expiry time should be about 15 minutes away
expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
- return picker_url, session_id
-
- def test_username_picker_use_displayname_avatar_and_email(self) -> None:
- """Test the happy path of a username picker flow with using displayname, avatar and email."""
-
- fake_oidc_server = self.helper.fake_oidc_server()
-
- mxid = "@bobby:test"
- displayname = "Jonny"
- email = "bobby@test.com"
- picture = "mxc://test/avatar_url"
-
- picker_url, session_id = self.proceed_to_username_picker_page(
- fake_oidc_server, displayname, email, picture
- )
-
- # Now, submit a username to the username picker, which should serve a redirect
- # to the completion page.
- # Also specify that we should use the provided displayname, avatar and email.
- content = urlencode(
- {
- b"username": b"bobby",
- b"use_display_name": b"true",
- b"use_avatar": b"true",
- b"use_email": email,
- }
- ).encode("utf8")
- chan = self.make_request(
- "POST",
- path=picker_url,
- content=content,
- content_is_form=True,
- custom_headers=[
- ("Cookie", "username_mapping_session=" + session_id),
- # old versions of twisted don't do form-parsing without a valid
- # content-length header.
- ("Content-Length", str(len(content))),
- ],
- )
- self.assertEqual(chan.code, 302, chan.result)
- location_headers = chan.headers.getRawHeaders("Location")
- assert location_headers
-
- # send a request to the completion page, which should 302 to the client redirectUrl
- chan = self.make_request(
- "GET",
- path=location_headers[0],
- custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
- )
- self.assertEqual(chan.code, 302, chan.result)
- location_headers = chan.headers.getRawHeaders("Location")
- assert location_headers
-
- # ensure that the returned location matches the requested redirect URL
- path, query = location_headers[0].split("?", 1)
- self.assertEqual(path, "https://x")
-
- # it will have url-encoded the params properly, so we'll have to parse them
- params = urllib.parse.parse_qsl(
- query, keep_blank_values=True, strict_parsing=True, errors="strict"
- )
- self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
- self.assertEqual(params[2][0], "loginToken")
-
- # fish the login token out of the returned redirect uri
- login_token = params[2][1]
-
- # finally, submit the matrix login token to the login API, which gives us our
- # matrix access token, mxid, and device id.
- chan = self.make_request(
- "POST",
- "/login",
- content={"type": "m.login.token", "token": login_token},
- )
- self.assertEqual(chan.code, 200, chan.result)
- self.assertEqual(chan.json_body["user_id"], mxid)
-
- # ensure the displayname and avatar from the OIDC response have been configured for the user.
- channel = self.make_request(
- "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertIn("mxc://test", channel.json_body["avatar_url"])
- self.assertEqual(displayname, channel.json_body["displayname"])
-
- # ensure the email from the OIDC response has been configured for the user.
- channel = self.make_request(
- "GET", "/account/3pid", access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertEqual(email, channel.json_body["threepids"][0]["address"])
-
- def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
- """Test the happy path of a username picker flow without using displayname, avatar or email."""
-
- fake_oidc_server = self.helper.fake_oidc_server()
-
- mxid = "@bobby:test"
- displayname = "Jonny"
- email = "bobby@test.com"
- picture = "mxc://test/avatar_url"
- username = "bobby"
-
- picker_url, session_id = self.proceed_to_username_picker_page(
- fake_oidc_server, displayname, email, picture
- )
-
# Now, submit a username to the username picker, which should serve a redirect
- # to the completion page.
- # Also specify that we should not use the provided displayname, avatar or email.
- content = urlencode(
- {
- b"username": username,
- b"use_display_name": b"false",
- b"use_avatar": b"false",
- }
- ).encode("utf8")
+ # to the completion page
+ content = urlencode({b"username": b"bobby"}).encode("utf8")
chan = self.make_request(
"POST",
path=picker_url,
@@ -1686,29 +1538,4 @@ class UsernamePickerTestCase(HomeserverTestCase):
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
- self.assertEqual(chan.json_body["user_id"], mxid)
-
- # ensure the displayname and avatar from the OIDC response have not been configured for the user.
- channel = self.make_request(
- "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertNotIn("avatar_url", channel.json_body)
- self.assertEqual(username, channel.json_body["displayname"])
-
- # ensure the email from the OIDC response has not been configured for the user.
- channel = self.make_request(
- "GET", "/account/3pid", access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertListEqual([], channel.json_body["threepids"])
-
-
-async def mock_get_file(
- url: str,
- output_stream: BinaryIO,
- max_size: Optional[int] = None,
- headers: Optional[RawHeaders] = None,
- is_allowed_content_type: Optional[Callable[[str], bool]] = None,
-) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
- return 0, {b"Content-Type": [b"image/png"]}, "", 200
+ self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index fbacf9d869..2913679c8b 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
deleted file mode 100644
index 30b6d31d0a..0000000000
--- a/tests/rest/client/test_media.py
+++ /dev/null
@@ -1,2677 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import base64
-import io
-import json
-import os
-import re
-import shutil
-from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
-from unittest.mock import MagicMock, Mock, patch
-from urllib import parse
-from urllib.parse import quote, urlencode
-
-from parameterized import parameterized, parameterized_class
-from PIL import Image as Image
-from typing_extensions import ClassVar
-
-from twisted.internet import defer
-from twisted.internet._resolver import HostResolution
-from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.defer import Deferred
-from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IAddress, IResolutionReceiver
-from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
-from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
-from twisted.web.resource import Resource
-
-from synapse.api.errors import HttpResponseException
-from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.oembed import OEmbedEndpointConfig
-from synapse.http.client import MultipartResponse
-from synapse.http.types import QueryParams
-from synapse.logging.context import make_deferred_yieldable
-from synapse.media._base import FileInfo, ThumbnailInfo
-from synapse.media.thumbnailer import ThumbnailProvider
-from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
-from synapse.rest import admin
-from synapse.rest.client import login, media
-from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
-from synapse.util import Clock
-from synapse.util.stringutils import parse_and_validate_mxc_uri
-
-from tests import unittest
-from tests.media.test_media_storage import (
- SVG,
- TestImage,
- empty_file,
- small_lossless_webp,
- small_png,
- small_png_with_transparency,
-)
-from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock
-from tests.test_utils import SMALL_PNG
-from tests.unittest import override_config
-
-try:
- import lxml
-except ImportError:
- lxml = None # type: ignore[assignment]
-
-
-class MediaDomainBlockingTests(unittest.HomeserverTestCase):
- remote_media_id = "doesnotmatter"
- remote_server_name = "evil.com"
- servlets = [
- media.register_servlets,
- admin.register_servlets,
- login.register_servlets,
- ]
-
- def make_homeserver(
- self, reactor: ThreadedMemoryReactorClock, clock: Clock
- ) -> HomeServer:
- config = self.default_config()
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
- config["media_store_path"] = self.media_store_path
-
- provider_config = {
- "module": "synapse.media.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
-
- config["media_storage_providers"] = [provider_config]
-
- return self.setup_test_homeserver(config=config)
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- # Inject a piece of media. We'll use this to ensure we're returning a sane
- # response when we're not supposed to block it, distinguishing a media block
- # from a regular 404.
- file_id = "abcdefg12345"
- file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
-
- media_storage = hs.get_media_repository().media_storage
-
- ctx = media_storage.store_into_file(file_info)
- (f, fname) = self.get_success(ctx.__aenter__())
- f.write(SMALL_PNG)
- self.get_success(ctx.__aexit__(None, None, None))
-
- self.get_success(
- self.store.store_cached_remote_media(
- origin=self.remote_server_name,
- media_id=self.remote_media_id,
- media_type="image/png",
- media_length=1,
- time_now_ms=clock.time_msec(),
- upload_name="test.png",
- filesystem_id=file_id,
- )
- )
- self.register_user("user", "password")
- self.tok = self.login("user", "password")
-
- @override_config(
- {
- # Disable downloads from the domain we'll be trying to download from.
- # Should result in a 404.
- "prevent_media_downloads_from": ["evil.com"],
- "dynamic_thumbnails": True,
- }
- )
- def test_cannot_download_blocked_media_thumbnail(self) -> None:
- """
- Same test as test_cannot_download_blocked_media but for thumbnails.
- """
- response = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
- shorthand=False,
- content={"width": 100, "height": 100},
- access_token=self.tok,
- )
- self.assertEqual(response.code, 404)
-
- @override_config(
- {
- # Disable downloads from a domain we won't be requesting downloads from.
- # This proves we haven't broken anything.
- "prevent_media_downloads_from": ["not-listed.com"],
- "dynamic_thumbnails": True,
- }
- )
- def test_remote_media_thumbnail_normally_unblocked(self) -> None:
- """
- Same test as test_remote_media_normally_unblocked but for thumbnails.
- """
- response = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(response.code, 200)
-
-
-class URLPreviewTests(unittest.HomeserverTestCase):
- if not lxml:
- skip = "url preview feature requires lxml"
-
- servlets = [media.register_servlets]
- hijack_auth = True
- user_id = "@test:user"
- end_content = (
- b"<html><head>"
- b'<meta property="og:title" content="~matrix~" />'
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
- config["url_preview_enabled"] = True
- config["max_spider_size"] = 9999999
- config["url_preview_ip_range_blacklist"] = (
- "192.168.1.1",
- "1.0.0.0/8",
- "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
- "2001:800::/21",
- )
- config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
- config["url_preview_accept_language"] = [
- "en-UK",
- "en-US;q=0.9",
- "fr;q=0.8",
- "*;q=0.7",
- ]
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
- config["media_store_path"] = self.media_store_path
-
- provider_config = {
- "module": "synapse.media.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
-
- config["media_storage_providers"] = [provider_config]
-
- hs = self.setup_test_homeserver(config=config)
-
- # After the hs is created, modify the parsed oEmbed config (to avoid
- # messing with files).
- #
- # Note that HTTP URLs are used to avoid having to deal with TLS in tests.
- hs.config.oembed.oembed_patterns = [
- OEmbedEndpointConfig(
- api_endpoint="http://publish.twitter.com/oembed",
- url_patterns=[
- re.compile(r"http://twitter\.com/.+/status/.+"),
- ],
- formats=None,
- ),
- OEmbedEndpointConfig(
- api_endpoint="http://www.hulu.com/api/oembed.{format}",
- url_patterns=[
- re.compile(r"http://www\.hulu\.com/watch/.+"),
- ],
- formats=["json"],
- ),
- ]
-
- return hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.media_repo = hs.get_media_repository()
- assert self.media_repo.url_previewer is not None
- self.url_previewer = self.media_repo.url_previewer
-
- self.lookups: Dict[str, Any] = {}
-
- class Resolver:
- def resolveHostName(
- _self,
- resolutionReceiver: IResolutionReceiver,
- hostName: str,
- portNumber: int = 0,
- addressTypes: Optional[Sequence[Type[IAddress]]] = None,
- transportSemantics: str = "TCP",
- ) -> IResolutionReceiver:
- resolution = HostResolution(hostName)
- resolutionReceiver.resolutionBegan(resolution)
- if hostName not in self.lookups:
- raise DNSLookupError("OH NO")
-
- for i in self.lookups[hostName]:
- resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber))
- resolutionReceiver.resolutionComplete()
- return resolutionReceiver
-
- self.reactor.nameResolver = Resolver() # type: ignore[assignment]
-
- def _assert_small_png(self, json_body: JsonDict) -> None:
- """Assert properties from the SMALL_PNG test image."""
- self.assertTrue(json_body["og:image"].startswith("mxc://"))
- self.assertEqual(json_body["og:image:height"], 1)
- self.assertEqual(json_body["og:image:width"], 1)
- self.assertEqual(json_body["og:image:type"], "image/png")
- self.assertEqual(json_body["matrix:image:size"], 67)
-
- def test_cache_returns_correct_type(self) -> None:
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
- % (len(self.end_content),)
- + self.end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
- )
-
- # Check the cache returns the correct response
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- )
-
- # Check the cache response has the same content
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
- )
-
- # Clear the in-memory cache
- self.assertIn("http://matrix.org", self.url_previewer._cache)
- self.url_previewer._cache.pop("http://matrix.org")
- self.assertNotIn("http://matrix.org", self.url_previewer._cache)
-
- # Check the database cache returns the correct response
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- )
-
- # Check the cache response has the same content
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
- )
-
- def test_non_ascii_preview_httpequiv(self) -> None:
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- end_content = (
- b"<html><head>"
- b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
- b'<meta property="og:title" content="\xe4\xea\xe0" />'
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
-
- def test_video_rejected(self) -> None:
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- end_content = b"anything"
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b"Content-Type: video/mp4\r\n\r\n"
- )
- % (len(end_content))
- + end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 502)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "Requested file's content type not allowed for this operation: video/mp4",
- },
- )
-
- def test_audio_rejected(self) -> None:
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- end_content = b"anything"
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b"Content-Type: audio/aac\r\n\r\n"
- )
- % (len(end_content))
- + end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 502)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "Requested file's content type not allowed for this operation: audio/aac",
- },
- )
-
- def test_non_ascii_preview_content_type(self) -> None:
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- end_content = (
- b"<html><head>"
- b'<meta property="og:title" content="\xe4\xea\xe0" />'
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
-
- def test_overlong_title(self) -> None:
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- end_content = (
- b"<html><head>"
- b"<title>" + b"x" * 2000 + b"</title>"
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- res = channel.json_body
- # We should only see the `og:description` field, as `title` is too long and should be stripped out
- self.assertCountEqual(["og:description"], res.keys())
-
- def test_ipaddr(self) -> None:
- """
- IP addresses can be previewed directly.
- """
- self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
- % (len(self.end_content),)
- + self.end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
- )
-
- def test_blocked_ip_specific(self) -> None:
- """
- Blocked IP addresses, found via DNS, are not spidered.
- """
- self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- )
-
- # No requests made.
- self.assertEqual(len(self.reactor.tcpClients), 0)
- self.assertEqual(channel.code, 502)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "DNS resolution failure during URL preview generation",
- },
- )
-
- def test_blocked_ip_range(self) -> None:
- """
- Blocked IP ranges, IPs found over DNS, are not spidered.
- """
- self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- )
-
- self.assertEqual(channel.code, 502)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "DNS resolution failure during URL preview generation",
- },
- )
-
- def test_blocked_ip_specific_direct(self) -> None:
- """
- Blocked IP addresses, accessed directly, are not spidered.
- """
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://192.168.1.1",
- shorthand=False,
- )
-
- # No requests made.
- self.assertEqual(len(self.reactor.tcpClients), 0)
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "IP address blocked"},
- )
- self.assertEqual(channel.code, 403)
-
- def test_blocked_ip_range_direct(self) -> None:
- """
- Blocked IP ranges, accessed directly, are not spidered.
- """
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://1.1.1.2",
- shorthand=False,
- )
-
- self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "IP address blocked"},
- )
-
- def test_blocked_ip_range_whitelisted_ip(self) -> None:
- """
- Blocked but then subsequently whitelisted IP addresses can be
- spidered.
- """
- self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
-
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
-
- client.dataReceived(
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
- % (len(self.end_content),)
- + self.end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
- )
-
- def test_blocked_ip_with_external_ip(self) -> None:
- """
- If a hostname resolves a blocked IP, even if there's a non-blocked one,
- it will be rejected.
- """
- # Hardcode the URL resolving to the IP we want.
- self.lookups["example.com"] = [
- (IPv4Address, "1.1.1.2"),
- (IPv4Address, "10.1.2.3"),
- ]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- )
- self.assertEqual(channel.code, 502)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "DNS resolution failure during URL preview generation",
- },
- )
-
- def test_blocked_ipv6_specific(self) -> None:
- """
- Blocked IP addresses, found via DNS, are not spidered.
- """
- self.lookups["example.com"] = [
- (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
- ]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- )
-
- # No requests made.
- self.assertEqual(len(self.reactor.tcpClients), 0)
- self.assertEqual(channel.code, 502)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "DNS resolution failure during URL preview generation",
- },
- )
-
- def test_blocked_ipv6_range(self) -> None:
- """
- Blocked IP ranges, IPs found over DNS, are not spidered.
- """
- self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- )
-
- self.assertEqual(channel.code, 502)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "DNS resolution failure during URL preview generation",
- },
- )
-
- def test_OPTIONS(self) -> None:
- """
- OPTIONS returns the OPTIONS.
- """
- channel = self.make_request(
- "OPTIONS",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- )
- self.assertEqual(channel.code, 204)
-
- def test_accept_language_config_option(self) -> None:
- """
- Accept-Language header is sent to the remote server
- """
- self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
-
- # Build and make a request to the server
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://example.com",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- # Extract Synapse's tcp client
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
-
- # Build a fake remote server to reply with
- server = AccumulatingProtocol()
-
- # Connect the two together
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
-
- # Tell Synapse that it has received some data from the remote server
- client.dataReceived(
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
- % (len(self.end_content),)
- + self.end_content
- )
-
- # Move the reactor along until we get a response on our original channel
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
- )
-
- # Check that the server received the Accept-Language header as part
- # of the request from Synapse
- self.assertIn(
- (
- b"Accept-Language: en-UK\r\n"
- b"Accept-Language: en-US;q=0.9\r\n"
- b"Accept-Language: fr;q=0.8\r\n"
- b"Accept-Language: *;q=0.7"
- ),
- server.data,
- )
-
- def test_image(self) -> None:
- """An image should be precached if mentioned in the HTML."""
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
-
- result = (
- b"""<html><body><img src="http://cdn.matrix.org/foo.png"></body></html>"""
- )
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- # Respond with the HTML.
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(result),)
- + result
- )
- self.pump()
-
- # Respond with the photo.
- client = self.reactor.tcpClients[1][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b"Content-Type: image/png\r\n\r\n"
- )
- % (len(SMALL_PNG),)
- + SMALL_PNG
- )
- self.pump()
-
- # The image should be in the result.
- self.assertEqual(channel.code, 200)
- self._assert_small_png(channel.json_body)
-
- def test_nonexistent_image(self) -> None:
- """If the preview image doesn't exist, ensure some data is returned."""
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- result = (
- b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
- )
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(result),)
- + result
- )
-
- self.pump()
-
- # There should not be a second connection.
- self.assertEqual(len(self.reactor.tcpClients), 1)
-
- # The image should not be in the result.
- self.assertEqual(channel.code, 200)
- self.assertNotIn("og:image", channel.json_body)
-
- @unittest.override_config(
- {"url_preview_url_blacklist": [{"netloc": "cdn.matrix.org"}]}
- )
- def test_image_blocked(self) -> None:
- """If the preview image doesn't exist, ensure some data is returned."""
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["cdn.matrix.org"] = [(IPv4Address, "10.1.2.4")]
-
- result = (
- b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
- )
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(result),)
- + result
- )
- self.pump()
-
- # There should not be a second connection.
- self.assertEqual(len(self.reactor.tcpClients), 1)
-
- # The image should not be in the result.
- self.assertEqual(channel.code, 200)
- self.assertNotIn("og:image", channel.json_body)
-
- def test_oembed_failure(self) -> None:
- """If the autodiscovered oEmbed URL fails, ensure some data is returned."""
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- result = b"""
- <title>oEmbed Autodiscovery Fail</title>
- <link rel="alternate" type="application/json+oembed"
- href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json"
- title="matrixdotorg" />
- """
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(result),)
- + result
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
-
- # The image should not be in the result.
- self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail")
-
- def test_data_url(self) -> None:
- """
- Requesting to preview a data URL is not supported.
- """
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- data = base64.b64encode(SMALL_PNG).decode()
-
- query_params = urlencode(
- {
- "url": f'<html><head><img src="data:image/png;base64,{data}" /></head></html>'
- }
- )
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/preview_url?{query_params}",
- shorthand=False,
- )
- self.pump()
-
- self.assertEqual(channel.code, 500)
-
- def test_inline_data_url(self) -> None:
- """
- An inline image (as a data URL) should be parsed properly.
- """
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- data = base64.b64encode(SMALL_PNG)
-
- end_content = (
- b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
- ) % (data,)
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- self._assert_small_png(channel.json_body)
-
- def test_oembed_photo(self) -> None:
- """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "photo",
- "url": "http://cdn.twitter.com/matrixdotorg",
- }
- oembed_content = json.dumps(result).encode("utf-8")
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(oembed_content),)
- + oembed_content
- )
-
- self.pump()
-
- # Ensure a second request is made to the photo URL.
- client = self.reactor.tcpClients[1][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b"Content-Type: image/png\r\n\r\n"
- )
- % (len(SMALL_PNG),)
- + SMALL_PNG
- )
-
- self.pump()
-
- # Ensure the URL is what was requested.
- self.assertIn(b"/matrixdotorg", server.data)
-
- self.assertEqual(channel.code, 200)
- body = channel.json_body
- self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
- self._assert_small_png(body)
-
- def test_oembed_rich(self) -> None:
- """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "rich",
- # Note that this provides the author, not the title.
- "author_name": "Alice",
- "html": "<div>Content Preview</div>",
- }
- end_content = json.dumps(result).encode("utf-8")
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
-
- self.pump()
-
- # Double check that the proper host is being connected to. (Note that
- # twitter.com can't be resolved so this is already implicitly checked.)
- self.assertIn(b"\r\nHost: publish.twitter.com\r\n", server.data)
-
- self.assertEqual(channel.code, 200)
- body = channel.json_body
- self.assertEqual(
- body,
- {
- "og:url": "http://twitter.com/matrixdotorg/status/12345",
- "og:title": "Alice",
- "og:description": "Content Preview",
- },
- )
-
- def test_oembed_format(self) -> None:
- """Test an oEmbed endpoint which requires the format in the URL."""
- self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "rich",
- "html": "<div>Content Preview</div>",
- }
- end_content = json.dumps(result).encode("utf-8")
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://www.hulu.com/watch/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
-
- self.pump()
-
- # The {format} should have been turned into json.
- self.assertIn(b"/api/oembed.json", server.data)
- # A URL parameter of format=json should be provided.
- self.assertIn(b"format=json", server.data)
-
- self.assertEqual(channel.code, 200)
- body = channel.json_body
- self.assertEqual(
- body,
- {
- "og:url": "http://www.hulu.com/watch/12345",
- "og:description": "Content Preview",
- },
- )
-
- @unittest.override_config(
- {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
- )
- def test_oembed_blocked(self) -> None:
- """The oEmbed URL should not be downloaded if the oEmbed URL is blocked."""
- self.lookups["twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
- self.assertEqual(channel.code, 403, channel.result)
-
- def test_oembed_autodiscovery(self) -> None:
- """
- Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
- 1. Request a preview of a URL which is not known to the oEmbed code.
- 2. It returns HTML including a link to an oEmbed preview.
- 3. The oEmbed preview is requested and returns a URL for an image.
- 4. The image is requested for thumbnailing.
- """
- # This is a little cheesy in that we use the www subdomain (which isn't the
- # list of oEmbed patterns) to get "raw" HTML response.
- self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = b"""
- <link rel="alternate" type="application/json+oembed"
- href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
- title="matrixdotorg" />
- """
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(result),)
- + result
- )
- self.pump()
-
- # The oEmbed response.
- result2 = {
- "version": "1.0",
- "type": "photo",
- "url": "http://cdn.twitter.com/matrixdotorg",
- }
- oembed_content = json.dumps(result2).encode("utf-8")
-
- # Ensure a second request is made to the oEmbed URL.
- client = self.reactor.tcpClients[1][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(oembed_content),)
- + oembed_content
- )
- self.pump()
-
- # Ensure the URL is what was requested.
- self.assertIn(b"/oembed?", server.data)
-
- # Ensure a third request is made to the photo URL.
- client = self.reactor.tcpClients[2][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b"Content-Type: image/png\r\n\r\n"
- )
- % (len(SMALL_PNG),)
- + SMALL_PNG
- )
- self.pump()
-
- # Ensure the URL is what was requested.
- self.assertIn(b"/matrixdotorg", server.data)
-
- self.assertEqual(channel.code, 200)
- body = channel.json_body
- self.assertEqual(
- body["og:url"], "http://www.twitter.com/matrixdotorg/status/12345"
- )
- self._assert_small_png(body)
-
- @unittest.override_config(
- {"url_preview_url_blacklist": [{"netloc": "publish.twitter.com"}]}
- )
- def test_oembed_autodiscovery_blocked(self) -> None:
- """
- If the discovered oEmbed URL is blocked, it should be discarded.
- """
- # This is a little cheesy in that we use the www subdomain (which isn't the
- # list of oEmbed patterns) to get "raw" HTML response.
- self.lookups["www.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.4")]
-
- result = b"""
- <title>Test</title>
- <link rel="alternate" type="application/json+oembed"
- href="http://publish.twitter.com/oembed?url=http%3A%2F%2Fcdn.twitter.com%2Fmatrixdotorg%2Fstatus%2F12345&format=json"
- title="matrixdotorg" />
- """
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(result),)
- + result
- )
-
- self.pump()
-
- # Ensure there's no additional connections.
- self.assertEqual(len(self.reactor.tcpClients), 1)
-
- # Ensure the URL is what was requested.
- self.assertIn(b"\r\nHost: www.twitter.com\r\n", server.data)
-
- self.assertEqual(channel.code, 200)
- body = channel.json_body
- self.assertEqual(body["og:title"], "Test")
- self.assertNotIn("og:image", body)
-
- def _download_image(self) -> Tuple[str, str]:
- """Downloads an image into the URL cache.
- Returns:
- A (host, media_id) tuple representing the MXC URI of the image.
- """
- self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=http://cdn.twitter.com/matrixdotorg",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: image/png\r\n\r\n"
- % (len(SMALL_PNG),)
- + SMALL_PNG
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
- body = channel.json_body
- mxc_uri = body["og:image"]
- host, _port, media_id = parse_and_validate_mxc_uri(mxc_uri)
- self.assertIsNone(_port)
- return host, media_id
-
- def test_storage_providers_exclude_files(self) -> None:
- """Test that files are not stored in or fetched from storage providers."""
- host, media_id = self._download_image()
-
- rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id)
- media_store_path = os.path.join(self.media_store_path, rel_file_path)
- storage_provider_path = os.path.join(self.storage_path, rel_file_path)
-
- # Check storage
- self.assertTrue(os.path.isfile(media_store_path))
- self.assertFalse(
- os.path.isfile(storage_provider_path),
- "URL cache file was unexpectedly stored in a storage provider",
- )
-
- # Check fetching
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/download/{host}/{media_id}",
- shorthand=False,
- await_result=False,
- )
- self.pump()
- self.assertEqual(channel.code, 200)
-
- # Move cached file into the storage provider
- os.makedirs(os.path.dirname(storage_provider_path), exist_ok=True)
- os.rename(media_store_path, storage_provider_path)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/download/{host}/{media_id}",
- shorthand=False,
- await_result=False,
- )
- self.pump()
- self.assertEqual(
- channel.code,
- 404,
- "URL cache file was unexpectedly retrieved from a storage provider",
- )
-
- def test_storage_providers_exclude_thumbnails(self) -> None:
- """Test that thumbnails are not stored in or fetched from storage providers."""
- host, media_id = self._download_image()
-
- rel_thumbnail_path = (
- self.media_repo.filepaths.url_cache_thumbnail_directory_rel(media_id)
- )
- media_store_thumbnail_path = os.path.join(
- self.media_store_path, rel_thumbnail_path
- )
- storage_provider_thumbnail_path = os.path.join(
- self.storage_path, rel_thumbnail_path
- )
-
- # Check storage
- self.assertTrue(os.path.isdir(media_store_thumbnail_path))
- self.assertFalse(
- os.path.isdir(storage_provider_thumbnail_path),
- "URL cache thumbnails were unexpectedly stored in a storage provider",
- )
-
- # Check fetching
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
- shorthand=False,
- await_result=False,
- )
- self.pump()
- self.assertEqual(channel.code, 200)
-
- # Remove the original, otherwise thumbnails will regenerate
- rel_file_path = self.media_repo.filepaths.url_cache_filepath_rel(media_id)
- media_store_path = os.path.join(self.media_store_path, rel_file_path)
- os.remove(media_store_path)
-
- # Move cached thumbnails into the storage provider
- os.makedirs(os.path.dirname(storage_provider_thumbnail_path), exist_ok=True)
- os.rename(media_store_thumbnail_path, storage_provider_thumbnail_path)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
- shorthand=False,
- await_result=False,
- )
- self.pump()
- self.assertEqual(
- channel.code,
- 404,
- "URL cache thumbnail was unexpectedly retrieved from a storage provider",
- )
-
- def test_cache_expiry(self) -> None:
- """Test that URL cache files and thumbnails are cleaned up properly on expiry."""
- _host, media_id = self._download_image()
-
- file_path = self.media_repo.filepaths.url_cache_filepath(media_id)
- file_dirs = self.media_repo.filepaths.url_cache_filepath_dirs_to_delete(
- media_id
- )
- thumbnail_dir = self.media_repo.filepaths.url_cache_thumbnail_directory(
- media_id
- )
- thumbnail_dirs = self.media_repo.filepaths.url_cache_thumbnail_dirs_to_delete(
- media_id
- )
-
- self.assertTrue(os.path.isfile(file_path))
- self.assertTrue(os.path.isdir(thumbnail_dir))
-
- self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1)
- self.get_success(self.url_previewer._expire_url_cache_data())
-
- for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs:
- self.assertFalse(
- os.path.exists(path),
- f"{os.path.relpath(path, self.media_store_path)} was not deleted",
- )
-
- @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]})
- def test_blocked_port(self) -> None:
- """Tests that blocking URLs with a port makes previewing such URLs
- fail with a 403 error and doesn't impact other previews.
- """
- self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
-
- bad_url = quote("http://matrix.org:8888/foo")
- good_url = quote("http://matrix.org/foo")
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=" + bad_url,
- shorthand=False,
- await_result=False,
- )
- self.pump()
- self.assertEqual(channel.code, 403, channel.result)
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=" + good_url,
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
- % (len(self.end_content),)
- + self.end_content
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
-
- @unittest.override_config(
- {"url_preview_url_blacklist": [{"netloc": "example.com"}]}
- )
- def test_blocked_url(self) -> None:
- """Tests that blocking URLs with a host makes previewing such URLs
- fail with a 403 error.
- """
- self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
-
- bad_url = quote("http://example.com/foo")
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/preview_url?url=" + bad_url,
- shorthand=False,
- await_result=False,
- )
- self.pump()
- self.assertEqual(channel.code, 403, channel.result)
-
-
-class MediaConfigTest(unittest.HomeserverTestCase):
- servlets = [
- media.register_servlets,
- admin.register_servlets,
- login.register_servlets,
- ]
-
- def make_homeserver(
- self, reactor: ThreadedMemoryReactorClock, clock: Clock
- ) -> HomeServer:
- config = self.default_config()
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
- config["media_store_path"] = self.media_store_path
-
- provider_config = {
- "module": "synapse.media.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
-
- config["media_storage_providers"] = [provider_config]
-
- return self.setup_test_homeserver(config=config)
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.register_user("user", "password")
- self.tok = self.login("user", "password")
-
- def test_media_config(self) -> None:
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/config",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size
- )
-
-
-class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
- servlets = [
- media.register_servlets,
- login.register_servlets,
- admin.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
- config["media_store_path"] = self.media_store_path
-
- provider_config = {
- "module": "synapse.media.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
-
- config["media_storage_providers"] = [provider_config]
-
- return self.setup_test_homeserver(config=config)
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.repo = hs.get_media_repository()
- self.client = hs.get_federation_http_client()
- self.store = hs.get_datastores().main
- self.user = self.register_user("user", "pass")
- self.tok = self.login("user", "pass")
-
- # mock actually reading file body
- def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred:
- d: Deferred = defer.Deferred()
- d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
- return d
-
- def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred:
- d: Deferred = defer.Deferred()
- d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
- return d
-
- @patch(
- "synapse.http.matrixfederationclient.read_multipart_response",
- read_multipart_response_30MiB,
- )
- def test_download_ratelimit_default(self) -> None:
- """
- Test remote media download ratelimiting against default configuration - 500MB bucket
- and 87kb/second drain rate
- """
-
- # mock out actually sending the request, returns a 30MiB response
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = 31457280
- resp.headers = Headers(
- {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
- )
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- # first request should go through
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abc",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel.code == 200
-
- # next 15 should go through
- for i in range(15):
- channel2 = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/download/remote.org/abc{i}",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel2.code == 200
-
- # 17th will hit ratelimit
- channel3 = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcd",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel3.code == 429
-
- # however, a request from a different IP will go through
- channel4 = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcde",
- shorthand=False,
- client_ip="187.233.230.159",
- access_token=self.tok,
- )
- assert channel4.code == 200
-
- # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
- # 30MiB download is authorized - The last download was blocked at 503,316,480.
- # The next download will be authorized when bucket hits 492,830,720
- # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
- # needs to drain before another download will be authorized, that will take ~=
- # 2 minutes (10,485,760/89,088/60)
- self.reactor.pump([2.0 * 60.0])
-
- # enough has drained and next request goes through
- channel5 = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcdef",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel5.code == 200
-
- @override_config(
- {
- "remote_media_download_per_second": "50M",
- "remote_media_download_burst_count": "50M",
- }
- )
- @patch(
- "synapse.http.matrixfederationclient.read_multipart_response",
- read_multipart_response_50MiB,
- )
- def test_download_rate_limit_config(self) -> None:
- """
- Test that download rate limit config options are correctly picked up and applied
- """
-
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = 52428800
- resp.headers = Headers(
- {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
- )
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- # first request should go through
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abc",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel.code == 200
-
- # immediate second request should fail
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcd",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel.code == 429
-
- # advance half a second
- self.reactor.pump([0.5])
-
- # request still fails
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcde",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel.code == 429
-
- # advance another half second
- self.reactor.pump([0.5])
-
- # enough has drained from bucket and request is successful
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcdef",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel.code == 200
-
- @override_config(
- {
- "remote_media_download_burst_count": "87M",
- }
- )
- @patch(
- "synapse.http.matrixfederationclient.read_multipart_response",
- read_multipart_response_30MiB,
- )
- def test_download_ratelimit_unknown_length(self) -> None:
- """
- Test that if no content-length is provided, ratelimiting is still applied after
- media is downloaded and length is known
- """
-
- # mock out actually sending the request
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = UNKNOWN_LENGTH
- resp.headers = Headers(
- {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
- )
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- # first 3 will go through (note that 3rd request technically violates rate limit but
- # that since the ratelimiting is applied *after* download it goes through, but next one fails)
- for i in range(3):
- channel2 = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/download/remote.org/abc{i}",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel2.code == 200
-
- # 4th will hit ratelimit
- channel3 = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcd",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel3.code == 429
-
- @override_config({"max_upload_size": "29M"})
- @patch(
- "synapse.http.matrixfederationclient.read_multipart_response",
- read_multipart_response_30MiB,
- )
- def test_max_download_respected(self) -> None:
- """
- Test that the max download size is enforced - note that max download size is determined
- by the max_upload_size
- """
-
- # mock out actually sending the request, returns a 30MiB response
- async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
- resp = MagicMock(spec=IResponse)
- resp.code = 200
- resp.length = 31457280
- resp.headers = Headers(
- {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
- )
- resp.phrase = b"OK"
- return resp
-
- self.client._send_request = _send_request # type: ignore
-
- channel = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/remote.org/abcd",
- shorthand=False,
- access_token=self.tok,
- )
- assert channel.code == 502
- assert channel.json_body["errcode"] == "M_TOO_LARGE"
-
- def test_file_download(self) -> None:
- content = io.BytesIO(b"file_to_stream")
- content_uri = self.get_success(
- self.repo.create_content(
- "text/plain",
- "test_upload",
- content,
- 46,
- UserID.from_string("@user_id:whatever.org"),
- )
- )
- # test with a text file
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/download/test/{content_uri.media_id}",
- shorthand=False,
- access_token=self.tok,
- )
- self.pump()
- self.assertEqual(200, channel.code)
-
-
-test_images = [
- small_png,
- small_png_with_transparency,
- small_lossless_webp,
- empty_file,
- SVG,
-]
-input_values = [(x,) for x in test_images]
-
-
-@parameterized_class(("test_image",), input_values)
-class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
- test_image: ClassVar[TestImage]
- servlets = [
- media.register_servlets,
- login.register_servlets,
- admin.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.fetches: List[
- Tuple[
- "Deferred[Any]",
- str,
- str,
- Optional[QueryParams],
- ]
- ] = []
-
- def federation_get_file(
- destination: str,
- path: str,
- output_stream: BinaryIO,
- download_ratelimiter: Ratelimiter,
- ip_address: Any,
- max_size: int,
- args: Optional[QueryParams] = None,
- retry_on_dns_fail: bool = True,
- ignore_backoff: bool = False,
- follow_redirects: bool = False,
- ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]":
- """A mock for MatrixFederationHttpClient.federation_get_file."""
-
- def write_to(
- r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
- ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
- data, response = r
- output_stream.write(data)
- return response
-
- def write_err(f: Failure) -> Failure:
- f.trap(HttpResponseException)
- output_stream.write(f.value.response)
- return f
-
- d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = (
- Deferred()
- )
- self.fetches.append((d, destination, path, args))
- # Note that this callback changes the value held by d.
- d_after_callback = d.addCallbacks(write_to, write_err)
- return make_deferred_yieldable(d_after_callback)
-
- def get_file(
- destination: str,
- path: str,
- output_stream: BinaryIO,
- download_ratelimiter: Ratelimiter,
- ip_address: Any,
- max_size: int,
- args: Optional[QueryParams] = None,
- retry_on_dns_fail: bool = True,
- ignore_backoff: bool = False,
- follow_redirects: bool = False,
- ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
- """A mock for MatrixFederationHttpClient.get_file."""
-
- def write_to(
- r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
- ) -> Tuple[int, Dict[bytes, List[bytes]]]:
- data, response = r
- output_stream.write(data)
- return response
-
- def write_err(f: Failure) -> Failure:
- f.trap(HttpResponseException)
- output_stream.write(f.value.response)
- return f
-
- d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
- self.fetches.append((d, destination, path, args))
- # Note that this callback changes the value held by d.
- d_after_callback = d.addCallbacks(write_to, write_err)
- return make_deferred_yieldable(d_after_callback)
-
- # Mock out the homeserver's MatrixFederationHttpClient
- client = Mock()
- client.federation_get_file = federation_get_file
- client.get_file = get_file
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
-
- config = self.default_config()
- config["media_store_path"] = self.media_store_path
- config["max_image_pixels"] = 2000000
-
- provider_config = {
- "module": "synapse.media.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
- config["media_storage_providers"] = [provider_config]
-
- hs = self.setup_test_homeserver(config=config, federation_http_client=client)
-
- return hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.media_repo = hs.get_media_repository()
-
- self.remote = "example.com"
- self.media_id = "12345"
-
- self.user = self.register_user("user", "pass")
- self.tok = self.login("user", "pass")
-
- def _req(
- self, content_disposition: Optional[bytes], include_content_type: bool = True
- ) -> FakeChannel:
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
- shorthand=False,
- await_result=False,
- access_token=self.tok,
- )
- self.pump()
-
- # We've made one fetch, to example.com, using the federation media URL
- self.assertEqual(len(self.fetches), 1)
- self.assertEqual(self.fetches[0][1], "example.com")
- self.assertEqual(
- self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id
- )
- self.assertEqual(
- self.fetches[0][3],
- {"timeout_ms": "20000"},
- )
-
- headers = {
- b"Content-Length": [b"%d" % (len(self.test_image.data))],
- }
-
- if include_content_type:
- headers[b"Content-Type"] = [self.test_image.content_type]
-
- if content_disposition:
- headers[b"Content-Disposition"] = [content_disposition]
-
- self.fetches[0][0].callback(
- (self.test_image.data, (len(self.test_image.data), headers, b"{}"))
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
-
- return channel
-
- def test_handle_missing_content_type(self) -> None:
- channel = self._req(
- b"attachment; filename=out" + self.test_image.extension,
- include_content_type=False,
- )
- headers = channel.headers
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
- )
-
- def test_disposition_filename_ascii(self) -> None:
- """
- If the filename is filename=<ascii> then Synapse will decode it as an
- ASCII string, and use filename= in the response.
- """
- channel = self._req(b"attachment; filename=out" + self.test_image.extension)
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
- )
- self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"),
- [
- (b"inline" if self.test_image.is_inline else b"attachment")
- + b"; filename=out"
- + self.test_image.extension
- ],
- )
-
- def test_disposition_filenamestar_utf8escaped(self) -> None:
- """
- If the filename is filename=*utf8''<utf8 escaped> then Synapse will
- correctly decode it as the UTF-8 string, and use filename* in the
- response.
- """
- filename = parse.quote("\u2603".encode()).encode("ascii")
- channel = self._req(
- b"attachment; filename*=utf-8''" + filename + self.test_image.extension
- )
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
- )
- self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"),
- [
- (b"inline" if self.test_image.is_inline else b"attachment")
- + b"; filename*=utf-8''"
- + filename
- + self.test_image.extension
- ],
- )
-
- def test_disposition_none(self) -> None:
- """
- If there is no filename, Content-Disposition should only
- be a disposition type.
- """
- channel = self._req(None)
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
- )
- self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"),
- [b"inline" if self.test_image.is_inline else b"attachment"],
- )
-
- def test_x_robots_tag_header(self) -> None:
- """
- Tests that the `X-Robots-Tag` header is present, which informs web crawlers
- to not index, archive, or follow links in media.
- """
- channel = self._req(b"attachment; filename=out" + self.test_image.extension)
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"X-Robots-Tag"),
- [b"noindex, nofollow, noarchive, noimageindex"],
- )
-
- def test_cross_origin_resource_policy_header(self) -> None:
- """
- Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
- allowing web clients to embed media from the downloads API.
- """
- channel = self._req(b"attachment; filename=out" + self.test_image.extension)
-
- headers = channel.headers
-
- self.assertEqual(
- headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
- [b"cross-origin"],
- )
-
- def test_unknown_federation_endpoint(self) -> None:
- """
- Test that if the download request to remote federation endpoint returns a 404
- we fall back to the _matrix/media endpoint
- """
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
- shorthand=False,
- await_result=False,
- access_token=self.tok,
- )
- self.pump()
-
- # We've made one fetch, to example.com, using the media URL, and asking
- # the other server not to do a remote fetch
- self.assertEqual(len(self.fetches), 1)
- self.assertEqual(self.fetches[0][1], "example.com")
- self.assertEqual(
- self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
- )
-
- # The result which says the endpoint is unknown.
- unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
- self.fetches[0][0].errback(
- HttpResponseException(404, "NOT FOUND", unknown_endpoint)
- )
-
- self.pump()
-
- # There should now be another request to the _matrix/media/v3/download URL.
- self.assertEqual(len(self.fetches), 2)
- self.assertEqual(self.fetches[1][1], "example.com")
- self.assertEqual(
- self.fetches[1][2],
- f"/_matrix/media/v3/download/example.com/{self.media_id}",
- )
-
- headers = {
- b"Content-Length": [b"%d" % (len(self.test_image.data))],
- }
-
- self.fetches[1][0].callback(
- (self.test_image.data, (len(self.test_image.data), headers))
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
-
- def test_thumbnail_crop(self) -> None:
- """Test that a cropped remote thumbnail is available."""
- self._test_thumbnail(
- "crop",
- self.test_image.expected_cropped,
- expected_found=self.test_image.expected_found,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- def test_thumbnail_scale(self) -> None:
- """Test that a scaled remote thumbnail is available."""
- self._test_thumbnail(
- "scale",
- self.test_image.expected_scaled,
- expected_found=self.test_image.expected_found,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- def test_invalid_type(self) -> None:
- """An invalid thumbnail type is never available."""
- self._test_thumbnail(
- "invalid",
- None,
- expected_found=False,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- @unittest.override_config(
- {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
- )
- def test_no_thumbnail_crop(self) -> None:
- """
- Override the config to generate only scaled thumbnails, but request a cropped one.
- """
- self._test_thumbnail(
- "crop",
- None,
- expected_found=False,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- @unittest.override_config(
- {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
- )
- def test_no_thumbnail_scale(self) -> None:
- """
- Override the config to generate only cropped thumbnails, but request a scaled one.
- """
- self._test_thumbnail(
- "scale",
- None,
- expected_found=False,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- def test_thumbnail_repeated_thumbnail(self) -> None:
- """Test that fetching the same thumbnail works, and deleting the on disk
- thumbnail regenerates it.
- """
- self._test_thumbnail(
- "scale",
- self.test_image.expected_scaled,
- expected_found=self.test_image.expected_found,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- if not self.test_image.expected_found:
- return
-
- # Fetching again should work, without re-requesting the image from the
- # remote.
- params = "?width=32&height=32&method=scale"
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
- shorthand=False,
- await_result=False,
- access_token=self.tok,
- )
- self.pump()
-
- self.assertEqual(channel.code, 200)
- if self.test_image.expected_scaled:
- self.assertEqual(
- channel.result["body"],
- self.test_image.expected_scaled,
- channel.result["body"],
- )
-
- # Deleting the thumbnail on disk then re-requesting it should work as
- # Synapse should regenerate missing thumbnails.
- info = self.get_success(
- self.store.get_cached_remote_media(self.remote, self.media_id)
- )
- assert info is not None
- file_id = info.filesystem_id
-
- thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
- self.remote, file_id
- )
- shutil.rmtree(thumbnail_dir, ignore_errors=True)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
- shorthand=False,
- await_result=False,
- access_token=self.tok,
- )
- self.pump()
-
- self.assertEqual(channel.code, 200)
- if self.test_image.expected_scaled:
- self.assertEqual(
- channel.result["body"],
- self.test_image.expected_scaled,
- channel.result["body"],
- )
-
- def _test_thumbnail(
- self,
- method: str,
- expected_body: Optional[bytes],
- expected_found: bool,
- unable_to_thumbnail: bool = False,
- ) -> None:
- """Test the given thumbnailing method works as expected.
-
- Args:
- method: The thumbnailing method to use (crop, scale).
- expected_body: The expected bytes from thumbnailing, or None if
- test should just check for a valid image.
- expected_found: True if the file should exist on the server, or False if
- a 404/400 is expected.
- unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
- False if the thumbnailing should succeed or a normal 404 is expected.
- """
-
- params = "?width=32&height=32&method=" + method
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
- shorthand=False,
- await_result=False,
- access_token=self.tok,
- )
- self.pump()
- headers = {
- b"Content-Length": [b"%d" % (len(self.test_image.data))],
- b"Content-Type": [self.test_image.content_type],
- }
- self.fetches[0][0].callback(
- (self.test_image.data, (len(self.test_image.data), headers))
- )
- self.pump()
- if expected_found:
- self.assertEqual(channel.code, 200)
-
- self.assertEqual(
- channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
- [b"cross-origin"],
- )
-
- if expected_body is not None:
- self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
- )
- else:
- # ensure that the result is at least some valid image
- Image.open(io.BytesIO(channel.result["body"]))
- elif unable_to_thumbnail:
- # A 400 with a JSON body.
- self.assertEqual(channel.code, 400)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "Cannot find any thumbnails for the requested media ('/_matrix/client/v1/media/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
- },
- )
- else:
- # A 404 with a JSON body.
- self.assertEqual(channel.code, 404)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_NOT_FOUND",
- "error": "Not found '/_matrix/client/v1/media/thumbnail/example.com/12345'",
- },
- )
-
- @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
- def test_same_quality(self, method: str, desired_size: int) -> None:
- """Test that choosing between thumbnails with the same quality rating succeeds.
-
- We are not particular about which thumbnail is chosen."""
-
- content_type = self.test_image.content_type.decode()
- media_repo = self.hs.get_media_repository()
- thumbnail_provider = ThumbnailProvider(
- self.hs, media_repo, media_repo.media_storage
- )
-
- self.assertIsNotNone(
- thumbnail_provider._select_thumbnail(
- desired_width=desired_size,
- desired_height=desired_size,
- desired_method=method,
- desired_type=content_type,
- # Provide two identical thumbnails which are guaranteed to have the same
- # quality rating.
- thumbnail_infos=[
- ThumbnailInfo(
- width=32,
- height=32,
- method=method,
- type=content_type,
- length=256,
- ),
- ThumbnailInfo(
- width=32,
- height=32,
- method=method,
- type=content_type,
- length=256,
- ),
- ],
- file_id=f"image{self.test_image.extension.decode()}",
- url_cache=False,
- server_name=None,
- )
- )
-
-
-configs = [
- {"extra_config": {"dynamic_thumbnails": True}},
- {"extra_config": {"dynamic_thumbnails": False}},
-]
-
-
-@parameterized_class(configs)
-class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
- extra_config: Dict[str, Any]
- servlets = [
- media.register_servlets,
- login.register_servlets,
- admin.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- self.clock = clock
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
- config["media_store_path"] = self.media_store_path
- config["enable_authenticated_media"] = True
-
- provider_config = {
- "module": "synapse.media.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
-
- config["media_storage_providers"] = [provider_config]
- config.update(self.extra_config)
-
- return self.setup_test_homeserver(config=config)
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.repo = hs.get_media_repository()
- self.client = hs.get_federation_http_client()
- self.store = hs.get_datastores().main
- self.user = self.register_user("user", "pass")
- self.tok = self.login("user", "pass")
-
- def create_resource_dict(self) -> Dict[str, Resource]:
- resources = super().create_resource_dict()
- resources["/_matrix/media"] = self.hs.get_media_repository_resource()
- return resources
-
- def test_authenticated_media(self) -> None:
- # upload some local media with authentication on
- channel = self.make_request(
- "POST",
- "_matrix/media/v3/upload?filename=test_png_upload",
- SMALL_PNG,
- self.tok,
- shorthand=False,
- content_type=b"image/png",
- custom_headers=[("Content-Length", str(67))],
- )
- self.assertEqual(channel.code, 200)
- res = channel.json_body.get("content_uri")
- assert res is not None
- uri = res.split("mxc://")[1]
-
- # request media over authenticated endpoint, should be found
- channel2 = self.make_request(
- "GET",
- f"_matrix/client/v1/media/download/{uri}",
- access_token=self.tok,
- shorthand=False,
- )
- self.assertEqual(channel2.code, 200)
-
- # request same media over unauthenticated media, should raise 404 not found
- channel3 = self.make_request(
- "GET", f"_matrix/media/v3/download/{uri}", shorthand=False
- )
- self.assertEqual(channel3.code, 404)
-
- # check thumbnails as well
- params = "?width=32&height=32&method=crop"
- channel4 = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/{uri}{params}",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(channel4.code, 200)
-
- params = "?width=32&height=32&method=crop"
- channel5 = self.make_request(
- "GET",
- f"/_matrix/media/r0/thumbnail/{uri}{params}",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(channel5.code, 404)
-
- # Inject a piece of remote media.
- file_id = "abcdefg12345"
- file_info = FileInfo(server_name="lonelyIsland", file_id=file_id)
-
- media_storage = self.hs.get_media_repository().media_storage
-
- ctx = media_storage.store_into_file(file_info)
- (f, fname) = self.get_success(ctx.__aenter__())
- f.write(SMALL_PNG)
- self.get_success(ctx.__aexit__(None, None, None))
-
- # we write the authenticated status when storing media, so this should pick up
- # config and authenticate the media
- self.get_success(
- self.store.store_cached_remote_media(
- origin="lonelyIsland",
- media_id="52",
- media_type="image/png",
- media_length=1,
- time_now_ms=self.clock.time_msec(),
- upload_name="remote_test.png",
- filesystem_id=file_id,
- )
- )
-
- # ensure we have thumbnails for the non-dynamic code path
- if self.extra_config == {"dynamic_thumbnails": False}:
- self.get_success(
- self.repo._generate_thumbnails(
- "lonelyIsland", "52", file_id, "image/png"
- )
- )
-
- channel6 = self.make_request(
- "GET",
- "_matrix/client/v1/media/download/lonelyIsland/52",
- access_token=self.tok,
- shorthand=False,
- )
- self.assertEqual(channel6.code, 200)
-
- channel7 = self.make_request(
- "GET", f"_matrix/media/v3/download/{uri}", shorthand=False
- )
- self.assertEqual(channel7.code, 404)
-
- params = "?width=32&height=32&method=crop"
- channel8 = self.make_request(
- "GET",
- f"/_matrix/client/v1/media/thumbnail/lonelyIsland/52{params}",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(channel8.code, 200)
-
- channel9 = self.make_request(
- "GET",
- f"/_matrix/media/r0/thumbnail/lonelyIsland/52{params}",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(channel9.code, 404)
-
- # Inject a piece of local media that isn't authenticated
- file_id = "abcdefg123456"
- file_info = FileInfo(None, file_id=file_id)
-
- ctx = media_storage.store_into_file(file_info)
- (f, fname) = self.get_success(ctx.__aenter__())
- f.write(SMALL_PNG)
- self.get_success(ctx.__aexit__(None, None, None))
-
- self.get_success(
- self.store.db_pool.simple_insert(
- "local_media_repository",
- {
- "media_id": "abcdefg123456",
- "media_type": "image/png",
- "created_ts": self.clock.time_msec(),
- "upload_name": "test_local",
- "media_length": 1,
- "user_id": "someone",
- "url_cache": None,
- "authenticated": False,
- },
- desc="store_local_media",
- )
- )
-
- # check that unauthenticated media is still available over both endpoints
- channel9 = self.make_request(
- "GET",
- "/_matrix/client/v1/media/download/test/abcdefg123456",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(channel9.code, 200)
-
- channel10 = self.make_request(
- "GET",
- "/_matrix/media/r0/download/test/abcdefg123456",
- shorthand=False,
- access_token=self.tok,
- )
- self.assertEqual(channel10.code, 200)
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
index f8a56c80ca..99d0feb10d 100644
--- a/tests/rest/client/test_models.py
+++ b/tests/rest/client/test_models.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -24,7 +23,7 @@ from typing import TYPE_CHECKING
from typing_extensions import Literal
from synapse._pydantic_compat import HAS_PYDANTIC_V2
-from synapse.types.rest.client import EmailRequestTokenBody
+from synapse.rest.client.models import EmailRequestTokenBody
if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import BaseModel, ValidationError
diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 637722ca0a..e25f4e4175 100644
--- a/tests/rest/client/test_mutual_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Half-Shot
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py
index e4b0455ce8..8f1cbcd5db 100644
--- a/tests/rest/client/test_notifications.py
+++ b/tests/rest/client/test_notifications.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -18,7 +17,6 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import List, Optional, Tuple
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -49,14 +47,6 @@ class HTTPPusherTests(HomeserverTestCase):
self.sync_handler = homeserver.get_sync_handler()
self.auth_handler = homeserver.get_auth_handler()
- self.user_id = self.register_user("user", "pass")
- self.access_token = self.login("user", "pass")
- self.other_user_id = self.register_user("otheruser", "pass")
- self.other_access_token = self.login("otheruser", "pass")
-
- # Create a room
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
-
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
@@ -70,22 +60,32 @@ class HTTPPusherTests(HomeserverTestCase):
"""
Local users will get notified for invites
"""
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
# Check we start with no pushes
- self._request_notifications(from_token=None, limit=1, expected_count=0)
+ channel = self.make_request(
+ "GET",
+ "/notifications",
+ access_token=other_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(len(channel.json_body["notifications"]), 0, channel.json_body)
# Send an invite
- self.helper.invite(
- room=self.room_id,
- src=self.user_id,
- targ=self.other_user_id,
- tok=self.access_token,
- )
+ self.helper.invite(room=room, src=user_id, targ=other_user_id, tok=access_token)
# We should have a notification now
channel = self.make_request(
"GET",
"/notifications",
- access_token=self.other_access_token,
+ access_token=other_access_token,
)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
@@ -94,139 +94,3 @@ class HTTPPusherTests(HomeserverTestCase):
"invite",
channel.json_body,
)
-
- def test_pagination_of_notifications(self) -> None:
- """
- Check that pagination of notifications works.
- """
- # Check we start with no pushes
- self._request_notifications(from_token=None, limit=1, expected_count=0)
-
- # Send an invite and have the other user join the room.
- self.helper.invite(
- room=self.room_id,
- src=self.user_id,
- targ=self.other_user_id,
- tok=self.access_token,
- )
- self.helper.join(self.room_id, self.other_user_id, tok=self.other_access_token)
-
- # Send 5 messages in the room and note down their event IDs.
- sent_event_ids = []
- for _ in range(5):
- resp = self.helper.send_event(
- self.room_id,
- "m.room.message",
- {"body": "honk", "msgtype": "m.text"},
- tok=self.access_token,
- )
- sent_event_ids.append(resp["event_id"])
-
- # We expect to get notifications for messages in reverse order.
- # So reverse this list of event IDs to make it easier to compare
- # against later.
- sent_event_ids.reverse()
-
- # We should have a few notifications now. Let's try and fetch the first 2.
- notification_event_ids, _ = self._request_notifications(
- from_token=None, limit=2, expected_count=2
- )
-
- # Check we got the expected event IDs back.
- self.assertEqual(notification_event_ids, sent_event_ids[:2])
-
- # Try requesting again without a 'from' query parameter. We should get the
- # same two notifications back.
- notification_event_ids, next_token = self._request_notifications(
- from_token=None, limit=2, expected_count=2
- )
- self.assertEqual(notification_event_ids, sent_event_ids[:2])
-
- # Ask for the next 5 notifications, though there should only be
- # 4 remaining; the next 3 messages and the invite.
- #
- # We need to use the "next_token" from the response as the "from"
- # query parameter in the next request in order to paginate.
- notification_event_ids, next_token = self._request_notifications(
- from_token=next_token, limit=5, expected_count=4
- )
- # Ensure we chop off the invite on the end.
- notification_event_ids = notification_event_ids[:-1]
- self.assertEqual(notification_event_ids, sent_event_ids[2:])
-
- def _request_notifications(
- self, from_token: Optional[str], limit: int, expected_count: int
- ) -> Tuple[List[str], str]:
- """
- Make a request to /notifications to get the latest events to be notified about.
-
- Only the event IDs are returned. The request is made by the "other user".
-
- Args:
- from_token: An optional starting parameter.
- limit: The maximum number of results to return.
- expected_count: The number of events to expect in the response.
-
- Returns:
- A list of event IDs that the client should be notified about.
- Events are returned newest-first.
- """
- # Construct the request path.
- path = f"/notifications?limit={limit}"
- if from_token is not None:
- path += f"&from={from_token}"
-
- channel = self.make_request(
- "GET",
- path,
- access_token=self.other_access_token,
- )
-
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- len(channel.json_body["notifications"]), expected_count, channel.json_body
- )
-
- # Extract the necessary data from the response.
- next_token = channel.json_body["next_token"]
- event_ids = [
- event["event"]["event_id"] for event in channel.json_body["notifications"]
- ]
-
- return event_ids, next_token
-
- def test_parameters(self) -> None:
- """
- Test that appropriate errors are returned when query parameters are malformed.
- """
- # Test that no parameters are required.
- channel = self.make_request(
- "GET",
- "/notifications",
- access_token=self.other_access_token,
- )
- self.assertEqual(channel.code, 200)
-
- # Test that limit cannot be negative
- channel = self.make_request(
- "GET",
- "/notifications?limit=-1",
- access_token=self.other_access_token,
- )
- self.assertEqual(channel.code, 400)
-
- # Test that the 'limit' parameter must be an integer.
- channel = self.make_request(
- "GET",
- "/notifications?limit=foobar",
- access_token=self.other_access_token,
- )
- self.assertEqual(channel.code, 400)
-
- # Test that the 'from' parameter must be an integer.
- channel = self.make_request(
- "GET",
- "/notifications?from=osborne",
- access_token=self.other_access_token,
- )
- self.assertEqual(channel.code, 400)
diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index f0ef733f7b..08d1cf56f2 100644
--- a/tests/rest/client/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 1584c2e96c..634cda9262 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index f98f3f77aa..b9852928c0 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index 9da0e7982f..0ee66b6a84 100644
--- a/tests/rest/client/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_read_marker.py b/tests/rest/client/test_read_marker.py
index 0b4ad685b3..ed0ac2af6e 100644
--- a/tests/rest/client/test_read_marker.py
+++ b/tests/rest/client/test_read_marker.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 Beeper
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -78,7 +77,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- f"/rooms/{room_id}/read_markers",
+ "/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_1,
},
@@ -90,7 +89,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
event_id_2 = send_message()
channel = self.make_request(
"POST",
- f"/rooms/{room_id}/read_markers",
+ "/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_2,
},
@@ -123,7 +122,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- f"/rooms/{room_id}/read_markers",
+ "/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_1,
},
@@ -142,7 +141,7 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase):
event_id_2 = send_message()
channel = self.make_request(
"POST",
- f"/rooms/{room_id}/read_markers",
+ "/rooms/!abc:beep/read_markers",
content={
"m.fully_read": event_id_2,
},
diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py
index f0648289f1..9fe49a12f6 100644
--- a/tests/rest/client/test_receipts.py
+++ b/tests/rest/client/test_receipts.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index b25e184786..2afb3f065a 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 694f143eff..13f27d13c9 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -22,7 +20,6 @@
import datetime
import os
from typing import Any, Dict, List, Tuple
-from unittest.mock import AsyncMock
import pkg_resources
@@ -43,7 +40,6 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
-from tests.server import ThreadedMemoryReactorClock
from tests.unittest import override_config
@@ -60,13 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
config["allow_guest_access"] = True
return config
- def make_homeserver(
- self, reactor: ThreadedMemoryReactorClock, clock: Clock
- ) -> HomeServer:
- hs = super().make_homeserver(reactor, clock)
- hs.get_send_email_handler()._sendmail = AsyncMock()
- return hs
-
def test_POST_appservice_registration_valid(self) -> None:
user_id = "@as_user_kermit:test"
as_token = "i_am_an_app_service"
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index f5a7602d0a..e0529d88c5 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -35,6 +34,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils.event_injection import inject_event
+from tests.unittest import override_config
class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -956,6 +956,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
class RecursiveRelationTestCase(BaseRelationsTestCase):
+ @override_config({"experimental_features": {"msc3981_recurse_relations": True}})
def test_recursive_relations(self) -> None:
"""Generate a complex, multi-level relationship tree and query it."""
# Create a thread with a few messages in it.
@@ -1001,7 +1002,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}"
- "?dir=f&limit=20&recurse=true",
+ "?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1022,6 +1023,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase):
],
)
+ @override_config({"experimental_features": {"msc3981_recurse_relations": True}})
def test_recursive_relations_with_filter(self) -> None:
"""The event_type and rel_type still apply."""
# Create a thread with a few messages in it.
@@ -1049,7 +1051,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- "?dir=f&limit=20&recurse=true",
+ "?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1062,7 +1064,7 @@ class RecursiveRelationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction"
- "?dir=f&limit=20&recurse=true",
+ "?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index 0ab754a11a..07e45f14f9 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -1,8 +1,7 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023-2024 New Vector, Ltd
+# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@@ -19,23 +18,16 @@
#
#
-from typing import Dict
-from urllib.parse import urlparse
-
from twisted.test.proto_helpers import MemoryReactor
-from twisted.web.resource import Resource
from synapse.rest.client import rendezvous
-from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
-from tests.utils import HAS_AUTHLIB
-msc3886_endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
-msc4108_endpoint = "/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
+endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
class RendezvousServletTestCase(unittest.HomeserverTestCase):
@@ -47,430 +39,12 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver()
return self.hs
- def create_resource_dict(self) -> Dict[str, Resource]:
- return {
- **super().create_resource_dict(),
- "/_synapse/client/rendezvous": MSC4108RendezvousSessionResource(self.hs),
- }
-
def test_disabled(self) -> None:
- channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 404)
- channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
self.assertEqual(channel.code, 404)
@override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
- def test_msc3886_redirect(self) -> None:
- channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
+ def test_redirect(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
self.assertEqual(channel.code, 307)
self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"])
-
- @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc4108_delegation_endpoint": "https://asd",
- "msc3861": {
- "enabled": True,
- "issuer": "https://issuer",
- "client_id": "client_id",
- "client_auth_method": "client_secret_post",
- "client_secret": "client_secret",
- "admin_token": "admin_token_value",
- },
- },
- }
- )
- def test_msc4108_delegation(self) -> None:
- channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 307)
- self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"])
-
- @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc4108_enabled": True,
- "msc3861": {
- "enabled": True,
- "issuer": "https://issuer",
- "client_id": "client_id",
- "client_auth_method": "client_secret_post",
- "client_secret": "client_secret",
- "admin_token": "admin_token_value",
- },
- },
- }
- )
- def test_msc4108(self) -> None:
- """
- Test the MSC4108 rendezvous endpoint, including:
- - Creating a session
- - Getting the data back
- - Updating the data
- - Deleting the data
- - ETag handling
- """
- # We can post arbitrary data to the endpoint
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_type=b"text/plain",
- access_token=None,
- )
- self.assertEqual(channel.code, 201)
- self.assertSubstring("/_synapse/client/rendezvous/", channel.json_body["url"])
- headers = dict(channel.headers.getAllRawHeaders())
- self.assertIn(b"ETag", headers)
- self.assertIn(b"Expires", headers)
- self.assertEqual(headers[b"Content-Type"], [b"application/json"])
- self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
- self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
- self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
- self.assertEqual(headers[b"Pragma"], [b"no-cache"])
- self.assertIn("url", channel.json_body)
- self.assertTrue(channel.json_body["url"].startswith("https://"))
-
- url = urlparse(channel.json_body["url"])
- session_endpoint = url.path
- etag = headers[b"ETag"][0]
-
- # We can get the data back
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
-
- self.assertEqual(channel.code, 200)
- headers = dict(channel.headers.getAllRawHeaders())
- self.assertEqual(headers[b"ETag"], [etag])
- self.assertIn(b"Expires", headers)
- self.assertEqual(headers[b"Content-Type"], [b"text/plain"])
- self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
- self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
- self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
- self.assertEqual(headers[b"Pragma"], [b"no-cache"])
- self.assertEqual(channel.text_body, "foo=bar")
-
- # We can make sure the data hasn't changed
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- custom_headers=[("If-None-Match", etag)],
- )
-
- self.assertEqual(channel.code, 304)
-
- # We can update the data
- channel = self.make_request(
- "PUT",
- session_endpoint,
- "foo=baz",
- content_type=b"text/plain",
- access_token=None,
- custom_headers=[("If-Match", etag)],
- )
-
- self.assertEqual(channel.code, 202)
- headers = dict(channel.headers.getAllRawHeaders())
- old_etag = etag
- new_etag = headers[b"ETag"][0]
-
- # If we try to update it again with the old etag, it should fail
- channel = self.make_request(
- "PUT",
- session_endpoint,
- "bar=baz",
- content_type=b"text/plain",
- access_token=None,
- custom_headers=[("If-Match", old_etag)],
- )
-
- self.assertEqual(channel.code, 412)
- self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN")
- self.assertEqual(
- channel.json_body["org.matrix.msc4108.errcode"], "M_CONCURRENT_WRITE"
- )
-
- # If we try to get with the old etag, we should get the updated data
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- custom_headers=[("If-None-Match", old_etag)],
- )
-
- self.assertEqual(channel.code, 200)
- headers = dict(channel.headers.getAllRawHeaders())
- self.assertEqual(headers[b"ETag"], [new_etag])
- self.assertEqual(channel.text_body, "foo=baz")
-
- # We can delete the data
- channel = self.make_request(
- "DELETE",
- session_endpoint,
- access_token=None,
- )
-
- self.assertEqual(channel.code, 204)
-
- # If we try to get the data again, it should fail
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
-
- self.assertEqual(channel.code, 404)
- self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
-
- @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc4108_enabled": True,
- "msc3861": {
- "enabled": True,
- "issuer": "https://issuer",
- "client_id": "client_id",
- "client_auth_method": "client_secret_post",
- "client_secret": "client_secret",
- "admin_token": "admin_token_value",
- },
- },
- }
- )
- def test_msc4108_expiration(self) -> None:
- """
- Test that entries are evicted after a TTL.
- """
- # Start a new session
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_type=b"text/plain",
- access_token=None,
- )
- self.assertEqual(channel.code, 201)
- session_endpoint = urlparse(channel.json_body["url"]).path
-
- # Sanity check that we can get the data back
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.text_body, "foo=bar")
-
- # Advance the clock, TTL of entries is 1 minute
- self.reactor.advance(60)
-
- # Get the data back, it should be gone
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
- self.assertEqual(channel.code, 404)
-
- @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc4108_enabled": True,
- "msc3861": {
- "enabled": True,
- "issuer": "https://issuer",
- "client_id": "client_id",
- "client_auth_method": "client_secret_post",
- "client_secret": "client_secret",
- "admin_token": "admin_token_value",
- },
- },
- }
- )
- def test_msc4108_capacity(self) -> None:
- """
- Test that a capacity limit is enforced on the rendezvous sessions, as old
- entries are evicted at an interval when the limit is reached.
- """
- # Start a new session
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_type=b"text/plain",
- access_token=None,
- )
- self.assertEqual(channel.code, 201)
- session_endpoint = urlparse(channel.json_body["url"]).path
-
- # Sanity check that we can get the data back
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.text_body, "foo=bar")
-
- # Start a lot of new sessions
- for _ in range(100):
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_type=b"text/plain",
- access_token=None,
- )
- self.assertEqual(channel.code, 201)
-
- # Get the data back, it should still be there, as the eviction hasn't run yet
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
-
- self.assertEqual(channel.code, 200)
-
- # Advance the clock, as it will trigger the eviction
- self.reactor.advance(1)
-
- # Get the data back, it should be gone
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
-
- @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc4108_enabled": True,
- "msc3861": {
- "enabled": True,
- "issuer": "https://issuer",
- "client_id": "client_id",
- "client_auth_method": "client_secret_post",
- "client_secret": "client_secret",
- "admin_token": "admin_token_value",
- },
- },
- }
- )
- def test_msc4108_hard_capacity(self) -> None:
- """
- Test that a hard capacity limit is enforced on the rendezvous sessions, as old
- entries are evicted immediately when the limit is reached.
- """
- # Start a new session
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_type=b"text/plain",
- access_token=None,
- )
- self.assertEqual(channel.code, 201)
- session_endpoint = urlparse(channel.json_body["url"]).path
- # We advance the clock to make sure that this entry is the "lowest" in the session list
- self.reactor.advance(1)
-
- # Sanity check that we can get the data back
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
- self.assertEqual(channel.code, 200)
- self.assertEqual(channel.text_body, "foo=bar")
-
- # Start a lot of new sessions
- for _ in range(200):
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_type=b"text/plain",
- access_token=None,
- )
- self.assertEqual(channel.code, 201)
-
- # Get the data back, it should already be gone as we hit the hard limit
- channel = self.make_request(
- "GET",
- session_endpoint,
- access_token=None,
- )
-
- self.assertEqual(channel.code, 404)
-
- @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc4108_enabled": True,
- "msc3861": {
- "enabled": True,
- "issuer": "https://issuer",
- "client_id": "client_id",
- "client_auth_method": "client_secret_post",
- "client_secret": "client_secret",
- "admin_token": "admin_token_value",
- },
- },
- }
- )
- def test_msc4108_content_type(self) -> None:
- """
- Test that the content-type is restricted to text/plain.
- """
- # We cannot post invalid content-type arbitrary data to the endpoint
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_is_form=True,
- access_token=None,
- )
- self.assertEqual(channel.code, 400)
- self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
-
- # Make a valid request
- channel = self.make_request(
- "POST",
- msc4108_endpoint,
- "foo=bar",
- content_type=b"text/plain",
- access_token=None,
- )
- self.assertEqual(channel.code, 201)
- url = urlparse(channel.json_body["url"])
- session_endpoint = url.path
- headers = dict(channel.headers.getAllRawHeaders())
- etag = headers[b"ETag"][0]
-
- # We can't update the data with invalid content-type
- channel = self.make_request(
- "PUT",
- session_endpoint,
- "foo=baz",
- content_is_form=True,
- access_token=None,
- custom_headers=[("If-Match", etag)],
- )
- self.assertEqual(channel.code, 400)
- self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
diff --git a/tests/rest/client/test_reporting.py b/tests/rest/client/test_report_event.py
index 009deb9cb0..41c03b6f68 100644
--- a/tests/rest/client/test_reporting.py
+++ b/tests/rest/client/test_report_event.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 Callum Brown
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -22,7 +21,7 @@
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.rest.client import login, reporting, room
+from synapse.rest.client import login, report_event, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -35,7 +34,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
- reporting.register_servlets,
+ report_event.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -139,92 +138,3 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
"POST", self.report_path, data, access_token=self.other_user_tok
)
self.assertEqual(response_status, channel.code, msg=channel.result["body"])
-
-
-class ReportRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- reporting.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.other_user = self.register_user("user", "pass")
- self.other_user_tok = self.login("user", "pass")
-
- self.room_id = self.helper.create_room_as(
- self.other_user, tok=self.other_user_tok, is_public=True
- )
- self.report_path = (
- f"/_matrix/client/unstable/org.matrix.msc4151/rooms/{self.room_id}/report"
- )
-
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
- def test_reason_str(self) -> None:
- data = {"reason": "this makes me sad"}
- self._assert_status(200, data)
-
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
- def test_no_reason(self) -> None:
- data = {"not_reason": "for typechecking"}
- self._assert_status(400, data)
-
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
- def test_reason_nonstring(self) -> None:
- data = {"reason": 42}
- self._assert_status(400, data)
-
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
- def test_reason_null(self) -> None:
- data = {"reason": None}
- self._assert_status(400, data)
-
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
- def test_cannot_report_nonexistent_room(self) -> None:
- """
- Tests that we don't accept event reports for rooms which do not exist.
- """
- channel = self.make_request(
- "POST",
- "/_matrix/client/unstable/org.matrix.msc4151/rooms/!bloop:example.org/report",
- {"reason": "i am very sad"},
- access_token=self.other_user_tok,
- shorthand=False,
- )
- self.assertEqual(404, channel.code, msg=channel.result["body"])
- self.assertEqual(
- "Room does not exist",
- channel.json_body["error"],
- msg=channel.result["body"],
- )
-
- def _assert_status(self, response_status: int, data: JsonDict) -> None:
- channel = self.make_request(
- "POST",
- self.report_path,
- data,
- access_token=self.other_user_tok,
- shorthand=False,
- )
- self.assertEqual(response_status, channel.code, msg=channel.result["body"])
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 1e5a1b0a4d..09a5d64349 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -163,11 +163,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
- filter_events_for_client(
- storage_controllers,
- self.user_id,
- events,
- )
+ filter_events_for_client(storage_controllers, self.user_id, events)
)
# We should only get one event back.
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index c559dfda83..0e71cdcd88 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -1,9 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -48,16 +45,7 @@ from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
-from synapse.rest.client import (
- account,
- directory,
- knock,
- login,
- profile,
- register,
- room,
- sync,
-)
+from synapse.rest.client import account, directory, login, profile, register, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock
@@ -102,7 +90,6 @@ class RoomPermissionsTestCase(RoomBase):
rmcreator_id = "@notme:red"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store_controllers = hs.get_storage_controllers()
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
self.uncreated_rmid = "!aa:test"
@@ -492,23 +479,6 @@ class RoomPermissionsTestCase(RoomBase):
expect_code=HTTPStatus.OK,
)
- def test_default_call_invite_power_level(self) -> None:
- pl_event = self.get_success(
- self.store_controllers.state.get_current_state_event(
- self.created_public_rmid, EventTypes.PowerLevels, ""
- )
- )
- assert pl_event is not None
- self.assertEqual(50, pl_event.content.get("m.call.invite"))
-
- private_pl_event = self.get_success(
- self.store_controllers.state.get_current_state_event(
- self.created_rmid, EventTypes.PowerLevels, ""
- )
- )
- assert private_pl_event is not None
- self.assertEqual(None, private_pl_event.content.get("m.call.invite"))
-
class RoomStateTestCase(RoomBase):
"""Tests /rooms/$room_id/state."""
@@ -742,7 +712,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(33, channel.resource_usage.db_txn_count)
+ self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -755,7 +725,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(35, channel.resource_usage.db_txn_count)
+ self.assertEqual(34, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -1163,7 +1133,6 @@ class RoomJoinTestCase(RoomBase):
admin.register_servlets,
login.register_servlets,
room.register_servlets,
- knock.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -1177,8 +1146,6 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
- self.store = hs.get_datastores().main
-
def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
@@ -1252,9 +1219,9 @@ class RoomJoinTestCase(RoomBase):
"""
# Register a dummy callback. Make it allow all room joins for now.
- return_value: Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes] = (
- synapse.module_api.NOT_SPAM
- )
+ return_value: Union[
+ Literal["NOT_SPAM"], Tuple[Codes, dict], Codes
+ ] = synapse.module_api.NOT_SPAM
async def user_may_join_room(
userid: str,
@@ -1329,57 +1296,6 @@ class RoomJoinTestCase(RoomBase):
expect_additional_fields=return_value[1],
)
- def test_suspended_user_cannot_join_room(self) -> None:
- # set the user as suspended
- self.get_success(self.store.set_user_suspended_status(self.user2, True))
-
- channel = self.make_request(
- "POST", f"/join/{self.room1}", access_token=self.tok2
- )
- self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
- channel = self.make_request(
- "POST", f"/rooms/{self.room1}/join", access_token=self.tok2
- )
- self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
- def test_suspended_user_cannot_knock_on_room(self) -> None:
- # set the user as suspended
- self.get_success(self.store.set_user_suspended_status(self.user2, True))
-
- channel = self.make_request(
- "POST",
- f"/_matrix/client/v3/knock/{self.room1}",
- access_token=self.tok2,
- content={},
- shorthand=False,
- )
- self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
- def test_suspended_user_cannot_invite_to_room(self) -> None:
- # set the user as suspended
- self.get_success(self.store.set_user_suspended_status(self.user1, True))
-
- # first user invites second user
- channel = self.make_request(
- "POST",
- f"/rooms/{self.room1}/invite",
- access_token=self.tok1,
- content={"user_id": self.user2},
- )
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -1745,9 +1661,9 @@ class RoomMessagesTestCase(RoomBase):
expected_fields: dict,
) -> None:
class SpamCheck:
- mock_return_value: Union[str, bool, Codes, Tuple[Codes, JsonDict], bool] = (
- "NOT_SPAM"
- )
+ mock_return_value: Union[
+ str, bool, Codes, Tuple[Codes, JsonDict], bool
+ ] = "NOT_SPAM"
mock_content: Optional[JsonDict] = None
async def check_event_for_spam(
@@ -2238,58 +2154,6 @@ class RoomMessageListTestCase(RoomBase):
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
- def test_room_message_filter_query_validation(self) -> None:
- # Test json validation in (filter) query parameter.
- # Does not test the validity of the filter, only the json validation.
-
- # Check Get with valid json filter parameter, expect 200.
- valid_filter_str = '{"types": ["m.room.message"]}'
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}",
- )
-
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
- invalid_filter_str = "}}}{}"
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}",
- )
-
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
- self.assertEqual(
- channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
- )
-
-
-class RoomMessageFilterTestCase(RoomBase):
- """Tests /rooms/$room_id/messages REST events."""
-
- user_id = "@sid1:red"
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.room_id = self.helper.create_room_as(self.user_id)
-
- def test_room_message_filter_wildcard(self) -> None:
- # Send a first message in the room, which will be removed by the purge.
- self.helper.send(self.room_id, "message 1", type="f.message.1")
- self.helper.send(self.room_id, "message 1", type="f.message.2")
- self.helper.send(self.room_id, "not returned in filter")
- channel = self.make_request(
- "GET",
- "/rooms/%s/messages?access_token=x&dir=b&filter=%s"
- % (
- self.room_id,
- json.dumps({"types": ["f.message.*"]}),
- ),
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- chunk = channel.json_body["chunk"]
- self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
-
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -3301,33 +3165,6 @@ class ContextTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
self.assertEqual(events_after[1].get("content"), {}, events_after[1])
- def test_room_event_context_filter_query_validation(self) -> None:
- # Test json validation in (filter) query parameter.
- # Does not test the validity of the filter, only the json validation.
- event_id = self.helper.send(self.room_id, "message 7", tok=self.tok)["event_id"]
-
- # Check Get with valid json filter parameter, expect 200.
- valid_filter_str = '{"types": ["m.room.message"]}'
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}",
- access_token=self.tok,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
- invalid_filter_str = "}}}{}"
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}",
- access_token=self.tok,
- )
-
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
- self.assertEqual(
- channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
- )
-
class RoomAliasListTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -3819,108 +3656,3 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
# Make sure the outlier event is not returned
self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id)
-
-
-class UserSuspensionTests(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- profile.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user1 = self.register_user("thomas", "hackme")
- self.tok1 = self.login("thomas", "hackme")
-
- self.user2 = self.register_user("teresa", "hackme")
- self.tok2 = self.login("teresa", "hackme")
-
- self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
- self.store = hs.get_datastores().main
-
- def test_suspended_user_cannot_send_message_to_room(self) -> None:
- # set the user as suspended
- self.get_success(self.store.set_user_suspended_status(self.user1, True))
-
- channel = self.make_request(
- "PUT",
- f"/rooms/{self.room1}/send/m.room.message/1",
- access_token=self.tok1,
- content={"body": "hello", "msgtype": "m.text"},
- )
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
- def test_suspended_user_cannot_change_profile_data(self) -> None:
- # set the user as suspended
- self.get_success(self.store.set_user_suspended_status(self.user1, True))
-
- channel = self.make_request(
- "PUT",
- f"/_matrix/client/v3/profile/{self.user1}/avatar_url",
- access_token=self.tok1,
- content={"avatar_url": "mxc://matrix.org/wefh34uihSDRGhw34"},
- shorthand=False,
- )
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
- channel2 = self.make_request(
- "PUT",
- f"/_matrix/client/v3/profile/{self.user1}/displayname",
- access_token=self.tok1,
- content={"displayname": "something offensive"},
- shorthand=False,
- )
- self.assertEqual(
- channel2.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
- def test_suspended_user_cannot_redact_messages_other_than_their_own(self) -> None:
- # first user sends message
- self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok2)
- res = self.helper.send_event(
- self.room1,
- "m.room.message",
- {"body": "hello", "msgtype": "m.text"},
- tok=self.tok2,
- )
- event_id = res["event_id"]
-
- # second user sends message
- self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok1)
- res2 = self.helper.send_event(
- self.room1,
- "m.room.message",
- {"body": "bad_message", "msgtype": "m.text"},
- tok=self.tok1,
- )
- event_id2 = res2["event_id"]
-
- # set the second user as suspended
- self.get_success(self.store.set_user_suspended_status(self.user1, True))
-
- # second user can't redact first user's message
- channel = self.make_request(
- "PUT",
- f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id}/1",
- access_token=self.tok1,
- content={"reason": "bogus"},
- shorthand=False,
- )
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
-
- # but can redact their own
- channel = self.make_request(
- "PUT",
- f"/_matrix/client/v3/rooms/{self.room1}/redact/{event_id2}/1",
- access_token=self.tok1,
- content={"reason": "bogus"},
- shorthand=False,
- )
- self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 5ef501c6d5..0a7676e566 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -18,39 +17,15 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
-from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
-@parameterized_class(
- ("sync_endpoint", "experimental_features"),
- [
- ("/sync", {}),
- (
- "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
- # Enable sliding sync
- {"msc3575_enabled": True},
- ),
- ],
-)
class SendToDeviceTestCase(HomeserverTestCase):
- """
- Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
-
- Attributes:
- sync_endpoint: The endpoint under test to use for syncing.
- experimental_features: The experimental features homeserver config to use.
- """
-
- sync_endpoint: str
- experimental_features: JsonDict
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -58,11 +33,6 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
- def default_config(self) -> JsonDict:
- config = super().default_config()
- config["experimental_features"] = self.experimental_features
- return config
-
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -83,7 +53,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
- channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
+ channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -96,19 +66,15 @@ class SendToDeviceTestCase(HomeserverTestCase):
}
self.assertEqual(channel.json_body["to_device"], expected_result)
- # it should re-appear if we do another sync because the to-device message is not
- # deleted until we acknowledge it by sending a `?since=...` parameter in the
- # next sync request corresponding to the `next_batch` value from the response.
- channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
+ # it should re-appear if we do another sync
+ channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET",
- f"{self.sync_endpoint}?since={sync_token}",
- access_token=user2_tok,
+ "GET", f"/sync?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -132,19 +98,15 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 200, chan.result)
- # now sync: we should get two of the three (because burst_count=2)
- channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
+ # now sync: we should get two of the three
+ channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
- {
- "sender": user1,
- "type": "m.room_key_request",
- "content": {"idx": i},
- },
+ {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
)
sync_token = channel.json_body["next_batch"]
@@ -162,9 +124,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET",
- f"{self.sync_endpoint}?since={sync_token}",
- access_token=user2_tok,
+ "GET", f"/sync?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -198,7 +158,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
- channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
+ channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -232,9 +192,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET",
- f"{self.sync_endpoint}?since={sync_token}",
- access_token=user2_tok,
+ "GET", f"/sync?since={sync_token}", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -258,7 +216,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
- channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
+ channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -274,9 +232,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
- "GET",
- f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
- access_token=user2_tok,
+ "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -284,9 +240,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET",
- f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
- access_token=user2_tok,
+ "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 2287f233b4..38eb2daf72 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 63df31ec75..d52f1cc34e 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -19,10 +18,9 @@
#
#
import json
-import logging
from typing import List
-from parameterized import parameterized, parameterized_class
+from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
@@ -44,8 +42,6 @@ from tests.federation.transport.test_knocking import (
)
from tests.server import TimedOutException
-logger = logging.getLogger(__name__)
-
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
@@ -691,180 +687,24 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
-@parameterized_class(
- ("sync_endpoint", "experimental_features"),
- [
- ("/sync", {}),
- (
- "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
- # Enable sliding sync
- {"msc3575_enabled": True},
- ),
- ],
-)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
- """
- Tests regarding device list (`device_lists`) changes.
-
- Attributes:
- sync_endpoint: The endpoint under test to use for syncing.
- experimental_features: The experimental features homeserver config to use.
- """
-
- sync_endpoint: str
- experimental_features: JsonDict
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
- room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
- def default_config(self) -> JsonDict:
- config = super().default_config()
- config["experimental_features"] = self.experimental_features
- return config
-
- def test_receiving_local_device_list_changes(self) -> None:
- """Tests that a local users that share a room receive each other's device list
- changes.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # Create a room for them to coexist peacefully in
- new_room_id = self.helper.create_room_as(
- alice_user_id, is_public=True, tok=alice_access_token
- )
- self.assertIsNotNone(new_room_id)
-
- # Have Bob join the room
- self.helper.invite(
- new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
- )
- self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
-
- # Now have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- self.sync_endpoint,
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that bob's incremental sync contains the updated device list.
- # If not, the client would only receive the device list update on the
- # *next* sync.
- bob_sync_channel.await_result()
- self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
-
- def test_not_receiving_local_device_list_changes(self) -> None:
- """Tests a local users DO NOT receive device updates from each other if they do not
- share a room.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # These users do not share a room. They are lonely.
-
- # Have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- self.sync_endpoint,
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that bob's incremental sync does not contain the updated device list.
- bob_sync_channel.await_result()
- self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertNotIn(
- alice_user_id, changed_device_lists, bob_sync_channel.json_body
- )
-
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
- test_device_id = "TESTDEVICE"
+ device_id = "TESTDEVICE"
# Register a user and login, creating a device
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey", device_id=device_id)
# Request an initial sync
- channel = self.make_request(
- "GET", self.sync_endpoint, access_token=alice_access_token
- )
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -872,19 +712,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
- f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
- access_token=alice_access_token,
+ f"/sync?since={next_batch}&timeout=30000",
+ access_token=self.tok,
await_result=False,
)
# Change our device's display name
channel = self.make_request(
"PUT",
- f"devices/{test_device_id}",
+ f"devices/{device_id}",
{
"display_name": "freeze ray",
},
- access_token=alice_access_token,
+ access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -898,229 +738,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
).get("changed", [])
self.assertIn(
- alice_user_id, device_list_changes, incremental_sync_channel.json_body
- )
-
-
-@parameterized_class(
- ("sync_endpoint", "experimental_features"),
- [
- ("/sync", {}),
- (
- "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
- # Enable sliding sync
- {"msc3575_enabled": True},
- ),
- ],
-)
-class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
- """
- Tests regarding device one time keys (`device_one_time_keys_count`) changes.
-
- Attributes:
- sync_endpoint: The endpoint under test to use for syncing.
- experimental_features: The experimental features homeserver config to use.
- """
-
- sync_endpoint: str
- experimental_features: JsonDict
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- sync.register_servlets,
- devices.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- config["experimental_features"] = self.experimental_features
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.e2e_keys_handler = hs.get_e2e_keys_handler()
-
- def test_no_device_one_time_keys(self) -> None:
- """
- Tests when no one time keys set, it still has the default `signed_curve25519` in
- `device_one_time_keys_count`
- """
- test_device_id = "TESTDEVICE"
-
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- # Request an initial sync
- channel = self.make_request(
- "GET", self.sync_endpoint, access_token=alice_access_token
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check for those one time key counts
- self.assertDictEqual(
- channel.json_body["device_one_time_keys_count"],
- # Note that "signed_curve25519" is always returned in key count responses
- # regardless of whether we uploaded any keys for it. This is necessary until
- # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
- {"signed_curve25519": 0},
- channel.json_body["device_one_time_keys_count"],
- )
-
- def test_returns_device_one_time_keys(self) -> None:
- """
- Tests that one time keys for the device/user are counted correctly in the `/sync`
- response
- """
- test_device_id = "TESTDEVICE"
-
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- # Upload one time keys for the user/device
- keys: JsonDict = {
- "alg1:k1": "key1",
- "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
- "alg2:k3": {"key": "key3"},
- }
- res = self.get_success(
- self.e2e_keys_handler.upload_keys_for_user(
- alice_user_id, test_device_id, {"one_time_keys": keys}
- )
- )
- # Note that "signed_curve25519" is always returned in key count responses
- # regardless of whether we uploaded any keys for it. This is necessary until
- # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
- self.assertDictEqual(
- res,
- {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
- )
-
- # Request an initial sync
- channel = self.make_request(
- "GET", self.sync_endpoint, access_token=alice_access_token
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check for those one time key counts
- self.assertDictEqual(
- channel.json_body["device_one_time_keys_count"],
- {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
- channel.json_body["device_one_time_keys_count"],
- )
-
-
-@parameterized_class(
- ("sync_endpoint", "experimental_features"),
- [
- ("/sync", {}),
- (
- "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
- # Enable sliding sync
- {"msc3575_enabled": True},
- ),
- ],
-)
-class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
- """
- Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
-
- Attributes:
- sync_endpoint: The endpoint under test to use for syncing.
- experimental_features: The experimental features homeserver config to use.
- """
-
- sync_endpoint: str
- experimental_features: JsonDict
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- sync.register_servlets,
- devices.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- config["experimental_features"] = self.experimental_features
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = self.hs.get_datastores().main
- self.e2e_keys_handler = hs.get_e2e_keys_handler()
-
- def test_no_device_unused_fallback_key(self) -> None:
- """
- Test when no unused fallback key is set, it just returns an empty list. The MSC
- says "The device_unused_fallback_key_types parameter must be present if the
- server supports fallback keys.",
- https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
- """
- test_device_id = "TESTDEVICE"
-
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- # Request an initial sync
- channel = self.make_request(
- "GET", self.sync_endpoint, access_token=alice_access_token
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check for those one time key counts
- self.assertListEqual(
- channel.json_body["device_unused_fallback_key_types"],
- [],
- channel.json_body["device_unused_fallback_key_types"],
- )
-
- def test_returns_device_one_time_keys(self) -> None:
- """
- Tests that device unused fallback key type is returned correctly in the `/sync`
- """
- test_device_id = "TESTDEVICE"
-
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- # We shouldn't have any unused fallback keys yet
- res = self.get_success(
- self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
- )
- self.assertEqual(res, [])
-
- # Upload a fallback key for the user/device
- self.get_success(
- self.e2e_keys_handler.upload_keys_for_user(
- alice_user_id,
- test_device_id,
- {"fallback_keys": {"alg1:k1": "fallback_key1"}},
- )
- )
- # We should now have an unused alg1 key
- fallback_res = self.get_success(
- self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
- )
- self.assertEqual(fallback_res, ["alg1"], fallback_res)
-
- # Request an initial sync
- channel = self.make_request(
- "GET", self.sync_endpoint, access_token=alice_access_token
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check for the unused fallback key types
- self.assertListEqual(
- channel.json_body["device_unused_fallback_key_types"],
- ["alg1"],
- channel.json_body["device_unused_fallback_key_types"],
+ self.user_id, device_list_changes, incremental_sync_channel.json_body
)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index d10df1a90f..22b2e34570 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index af1eecbb34..d7f479786d 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index 805c49b540..1a83e70f48 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index c4b15c5ae7..7038e42058 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index e43140720d..7bcbed246b 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -1,9 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -87,7 +84,8 @@ class RestHelper:
expect_code: Literal[200] = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
- ) -> str: ...
+ ) -> str:
+ ...
@overload
def create_room_as(
@@ -99,7 +97,8 @@ class RestHelper:
expect_code: int = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
- ) -> Optional[str]: ...
+ ) -> Optional[str]:
+ ...
def create_room_as(
self,
@@ -170,16 +169,14 @@ class RestHelper:
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
- extra_data: Optional[dict] = None,
- ) -> JsonDict:
- return self.change_membership(
+ ) -> None:
+ self.change_membership(
room=room,
src=src,
targ=targ,
tok=tok,
membership=Membership.INVITE,
expect_code=expect_code,
- extra_data=extra_data,
)
def join(
@@ -191,8 +188,8 @@ class RestHelper:
appservice_user_id: Optional[str] = None,
expect_errcode: Optional[Codes] = None,
expect_additional_fields: Optional[dict] = None,
- ) -> JsonDict:
- return self.change_membership(
+ ) -> None:
+ self.change_membership(
room=room,
src=user,
targ=user,
@@ -244,8 +241,8 @@ class RestHelper:
user: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
- ) -> JsonDict:
- return self.change_membership(
+ ) -> None:
+ self.change_membership(
room=room,
src=user,
targ=user,
@@ -261,9 +258,9 @@ class RestHelper:
targ: str,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
- ) -> JsonDict:
+ ) -> None:
"""A convenience helper: `change_membership` with `membership` preset to "ban"."""
- return self.change_membership(
+ self.change_membership(
room=room,
src=src,
targ=targ,
@@ -284,7 +281,7 @@ class RestHelper:
expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None,
expect_additional_fields: Optional[dict] = None,
- ) -> JsonDict:
+ ) -> None:
"""
Send a membership state event into a room.
@@ -300,9 +297,6 @@ class RestHelper:
using an application service access token in `tok`.
expect_code: The expected HTTP response code
expect_errcode: The expected Matrix error code
-
- Returns:
- The JSON response
"""
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -330,12 +324,9 @@ class RestHelper:
data,
)
- assert (
- channel.code == expect_code
- ), "Expected: %d, got: %d, PUT %s -> resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
channel.code,
- path,
channel.result["body"],
)
@@ -364,7 +355,6 @@ class RestHelper:
)
self.auth_user_id = temp_id
- return channel.json_body
def send(
self,
@@ -374,7 +364,6 @@ class RestHelper:
tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
- type: str = "m.room.message",
) -> JsonDict:
if body is None:
body = "body_text_here"
@@ -383,7 +372,7 @@ class RestHelper:
return self.send_event(
room_id,
- type,
+ "m.room.message",
content,
txn_id,
tok,
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 21e12b2a2f..c404b85ec6 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py
index 72205c6bb3..b0fbdbfcf4 100644
--- a/tests/rest/media/test_domain_blocking.py
+++ b/tests/rest/media/test_domain_blocking.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -44,13 +43,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
-
- media_storage = hs.get_media_repository().media_storage
-
- ctx = media_storage.store_into_file(file_info)
- (f, fname) = self.get_success(ctx.__aenter__())
- f.write(SMALL_PNG)
- self.get_success(ctx.__aexit__(None, None, None))
+ with hs.get_media_repository().media_storage.store_into_file(file_info) as (
+ f,
+ fname,
+ finish,
+ ):
+ f.write(SMALL_PNG)
+ self.get_success(finish())
self.get_success(
self.store.store_cached_remote_media(
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index a96f0e7fca..0b952e185d 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/rest/synapse/__init__.py b/tests/rest/synapse/__init__.py
deleted file mode 100644
index e5138f67e1..0000000000
--- a/tests/rest/synapse/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
diff --git a/tests/rest/synapse/client/__init__.py b/tests/rest/synapse/client/__init__.py
deleted file mode 100644
index e5138f67e1..0000000000
--- a/tests/rest/synapse/client/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
diff --git a/tests/rest/synapse/client/test_federation_whitelist.py b/tests/rest/synapse/client/test_federation_whitelist.py
deleted file mode 100644
index f0067a8f2b..0000000000
--- a/tests/rest/synapse/client/test_federation_whitelist.py
+++ /dev/null
@@ -1,119 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2024 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-
-from typing import Dict
-
-from twisted.web.resource import Resource
-
-from synapse.rest import admin
-from synapse.rest.client import login
-from synapse.rest.synapse.client import build_synapse_client_resource_tree
-
-from tests import unittest
-
-
-class FederationWhitelistTests(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- ]
-
- def create_resource_dict(self) -> Dict[str, Resource]:
- base = super().create_resource_dict()
- base.update(build_synapse_client_resource_tree(self.hs))
- return base
-
- def test_default(self) -> None:
- "If the config option is not enabled, the endpoint should 404"
- channel = self.make_request(
- "GET", "/_synapse/client/v1/config/federation_whitelist", shorthand=False
- )
-
- self.assertEqual(channel.code, 404)
-
- @unittest.override_config({"federation_whitelist_endpoint_enabled": True})
- def test_no_auth(self) -> None:
- "Endpoint requires auth when enabled"
-
- channel = self.make_request(
- "GET", "/_synapse/client/v1/config/federation_whitelist", shorthand=False
- )
-
- self.assertEqual(channel.code, 401)
-
- @unittest.override_config({"federation_whitelist_endpoint_enabled": True})
- def test_no_whitelist(self) -> None:
- "Test when there is no whitelist configured"
-
- self.register_user("user", "password")
- tok = self.login("user", "password")
-
- channel = self.make_request(
- "GET",
- "/_synapse/client/v1/config/federation_whitelist",
- shorthand=False,
- access_token=tok,
- )
-
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"whitelist_enabled": False, "whitelist": []}
- )
-
- @unittest.override_config(
- {
- "federation_whitelist_endpoint_enabled": True,
- "federation_domain_whitelist": ["example.com"],
- }
- )
- def test_whitelist(self) -> None:
- "Test when there is a whitelist configured"
-
- self.register_user("user", "password")
- tok = self.login("user", "password")
-
- channel = self.make_request(
- "GET",
- "/_synapse/client/v1/config/federation_whitelist",
- shorthand=False,
- access_token=tok,
- )
-
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"whitelist_enabled": True, "whitelist": ["example.com"]}
- )
-
- @unittest.override_config(
- {
- "federation_whitelist_endpoint_enabled": True,
- "federation_domain_whitelist": ["example.com", "example.com"],
- }
- )
- def test_whitelist_no_duplicates(self) -> None:
- "Test when there is a whitelist configured with duplicates, no duplicates are returned"
-
- self.register_user("user", "password")
- tok = self.login("user", "password")
-
- channel = self.make_request(
- "GET",
- "/_synapse/client/v1/config/federation_whitelist",
- shorthand=False,
- access_token=tok,
- )
-
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"whitelist_enabled": True, "whitelist": ["example.com"]}
- )
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index bdbfce796a..5c6fd57df6 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/server.py b/tests/server.py
index 3e377585ce..f94d0e4397 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -47,7 +46,7 @@ from typing import (
Union,
cast,
)
-from unittest.mock import Mock, patch
+from unittest.mock import Mock
import attr
from incremental import Version
@@ -55,7 +54,6 @@ from typing_extensions import ParamSpec
from zope.interface import implementer
import twisted
-from twisted.enterprise import adbapi
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
@@ -85,7 +83,6 @@ from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
-from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.events.presence_router import load_legacy_presence_router
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
@@ -96,8 +93,8 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
)
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.database import LoggingDatabaseConnection, make_pool
-from synapse.storage.engines import BaseDatabaseEngine, create_engine
+from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
@@ -198,35 +195,17 @@ class FakeChannel:
def headers(self) -> Headers:
if not self.result:
raise Exception("No result yet.")
-
- h = self.result["headers"]
- assert isinstance(h, Headers)
+ h = Headers()
+ for i in self.result["headers"]:
+ h.addRawHeader(*i)
return h
def writeHeaders(
- self,
- version: bytes,
- code: bytes,
- reason: bytes,
- headers: Union[Headers, List[Tuple[bytes, bytes]]],
+ self, version: bytes, code: bytes, reason: bytes, headers: Headers
) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
-
- if isinstance(headers, list):
- # Support prior to Twisted 24.7.0rc1
- new_headers = Headers()
- for k, v in headers:
- assert isinstance(k, bytes), f"key is not of type bytes: {k!r}"
- assert isinstance(v, bytes), f"value is not of type bytes: {v!r}"
- new_headers.addRawHeader(k, v)
- headers = new_headers
-
- assert isinstance(
- headers, Headers
- ), f"headers are of the wrong type: {headers!r}"
-
self.result["headers"] = headers
def write(self, data: bytes) -> None:
@@ -307,6 +286,10 @@ class FakeChannel:
self._reactor.run()
while not self.is_finished():
+ # If there's a producer, tell it to resume producing so we get content
+ if self._producer:
+ self._producer.resumeProducing()
+
if self._reactor.seconds() > end_time:
raise TimedOutException("Timed out waiting for request to finish.")
@@ -366,7 +349,6 @@ def make_request(
request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
- content_type: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
@@ -389,8 +371,6 @@ def make_request(
with the usual REST API path, if it doesn't contain it.
federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
- content_type: The content-type to use for the request. If not set then will default to
- application/json unless content_is_form is true.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
await_result: whether to wait for the request to complete rendering. If true,
@@ -454,9 +434,7 @@ def make_request(
)
if content:
- if content_type is not None:
- req.requestHeaders.addRawHeader(b"Content-Type", content_type)
- elif content_is_form:
+ if content_is_form:
req.requestHeaders.addRawHeader(
b"Content-Type", b"application/x-www-form-urlencoded"
)
@@ -691,53 +669,6 @@ def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
)
-def make_fake_db_pool(
- reactor: ISynapseReactor,
- db_config: DatabaseConnectionConfig,
- engine: BaseDatabaseEngine,
-) -> adbapi.ConnectionPool:
- """Wrapper for `make_pool` which builds a pool which runs db queries synchronously.
-
- For more deterministic testing, we don't use a regular db connection pool: instead
- we run all db queries synchronously on the test reactor's main thread. This function
- is a drop-in replacement for the normal `make_pool` which builds such a connection
- pool.
- """
- pool = make_pool(reactor, db_config, engine)
-
- def runWithConnection(
- func: Callable[..., R], *args: Any, **kwargs: Any
- ) -> Awaitable[R]:
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runWithConnection,
- func,
- *args,
- **kwargs,
- )
-
- def runInteraction(
- desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
- ) -> Awaitable[R]:
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runInteraction,
- desc,
- func,
- *args,
- **kwargs,
- )
-
- pool.runWithConnection = runWithConnection # type: ignore[method-assign]
- pool.runInteraction = runInteraction # type: ignore[assignment]
- # Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(reactor)
- pool.running = True
- return pool
-
-
class ThreadPool:
"""
Threadless thread pool.
@@ -774,6 +705,52 @@ class ThreadPool:
return d
+def _make_test_homeserver_synchronous(server: HomeServer) -> None:
+ """
+ Make the given test homeserver's database interactions synchronous.
+ """
+
+ clock = server.get_clock()
+
+ for database in server.get_datastores().databases:
+ pool = database._db_pool
+
+ def runWithConnection(
+ func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs,
+ )
+
+ def runInteraction(
+ desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ desc,
+ func,
+ *args,
+ **kwargs,
+ )
+
+ pool.runWithConnection = runWithConnection # type: ignore[method-assign]
+ pool.runInteraction = runInteraction # type: ignore[assignment]
+ # Replace the thread pool with a threadless 'thread' pool
+ pool.threadpool = ThreadPool(clock._reactor)
+ pool.running = True
+
+ # We've just changed the Databases to run DB transactions on the same
+ # thread, so we need to disable the dedicated thread behaviour.
+ server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
+
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
@@ -960,7 +937,7 @@ def connect_client(
class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
+ DATASTORE_CLASS = DataStore # type: ignore[assignment]
def setup_test_homeserver(
@@ -1089,14 +1066,7 @@ def setup_test_homeserver(
# Mock TLS
hs.tls_server_context_factory = Mock()
- # Patch `make_pool` before initialising the database, to make database transactions
- # synchronous for testing.
- with patch("synapse.storage.database.make_pool", side_effect=make_fake_db_pool):
- hs.setup()
-
- # Since we've changed the databases to run DB transactions on the same
- # thread, we need to stop the event fetcher hogging that one thread.
- hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+ hs.setup()
if USE_POSTGRES_FOR_TESTS:
database_pool = hs.get_datastores().databases[0]
@@ -1166,16 +1136,14 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
+ # Make the threadpool and database transactions synchronous for testing.
+ _make_test_homeserver_synchronous(hs)
+
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
for module, module_config in hs.config.modules.loaded_modules:
module(config=module_config, api=module_api)
- if hs.config.auto_accept_invites.enabled:
- # Start the local auto_accept_invites module.
- m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
- logger.info("Loaded local module %s", m)
-
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
diff --git a/tests/storage/databases/__init__.py b/tests/storage/databases/__init__.py
index 7e89a998b5..3d833a2e44 100644
--- a/tests/storage/databases/__init__.py
+++ b/tests/storage/databases/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/__init__.py b/tests/storage/databases/main/__init__.py
index 7e89a998b5..3d833a2e44 100644
--- a/tests/storage/databases/main/__init__.py
+++ b/tests/storage/databases/main/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py
index 3900a35290..e493588b5d 100644
--- a/tests/storage/databases/main/test_cache.py
+++ b/tests/storage/databases/main/test_cache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index f89f96fdcb..bb1daeacb7 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/test_end_to_end_keys.py b/tests/storage/databases/main/test_end_to_end_keys.py
index 1ed1d01cea..fd6b3e9f98 100644
--- a/tests/storage/databases/main/test_end_to_end_keys.py
+++ b/tests/storage/databases/main/test_end_to_end_keys.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index fd1f5e7fd5..caa5752032 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 1df71e723e..66c8864b1a 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index da2ec26421..e000394db2 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 88a5aa8cb1..232aad5b79 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 506d981ce6..b3dc4fe848 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 2859bcf4bd..42fbc4db60 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 10533f45d7..baa55ca862 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index b28db6a4ad..a83dfd1717 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 9420d03841..d0176309cd 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index d5b9996284..058c6caf90 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -337,15 +336,15 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly.
"""
- self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["1"] = (
- 100000
- )
- self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["2"] = (
- 200000
- )
- self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["3"] = (
- 300000
- )
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "1"
+ ] = 100000
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "2"
+ ] = 200000
+ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+ "3"
+ ] = 300000
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 13f78ee2d2..91c76e82ff 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 8af0d6265b..550ca26412 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ba01b038ab..cb8cb06c84 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -36,14 +35,6 @@ class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def default_config(self) -> JsonDict:
- config = super().default_config()
-
- # We 'enable' federation otherwise `get_device_updates_by_remote` will
- # throw an exception.
- config["federation_sender_instances"] = ["master"]
- return config
-
def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index f1602fdc86..7be9ee102a 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 931d37e85a..40d2f3d88c 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index bd594d3c1f..44581a4cc0 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c4e216c308..e3d592e542 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -21,8 +20,6 @@
from typing import Dict, List, Set, Tuple, cast
-from parameterized import parameterized
-
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
@@ -47,8 +44,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
self.store = hs.get_datastores().main
self._next_stream_ordering = 1
- @parameterized.expand([(False,), (True,)])
- def test_simple(self, batched: bool) -> None:
+ def test_simple(self) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
@@ -56,7 +52,6 @@ class EventChainStoreTestCase(HomeserverTestCase):
event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
- charlie = "@charlie:test"
room_id = "!room:test"
# Ensure that we have a rooms entry so that we generate the chain index.
@@ -195,26 +190,6 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
)
- charlie_invite = self.get_success(
- event_factory.for_room_version(
- RoomVersions.V6,
- {
- "type": EventTypes.Member,
- "state_key": charlie,
- "sender": alice,
- "room_id": room_id,
- "content": {"tag": "charlie_invite"},
- },
- ).build(
- prev_event_ids=[],
- auth_event_ids=[
- create.event_id,
- alice_join2.event_id,
- power_2.event_id,
- ],
- )
- )
-
events = [
create,
bob_join,
@@ -224,41 +199,33 @@ class EventChainStoreTestCase(HomeserverTestCase):
bob_join_2,
power_2,
alice_join2,
- charlie_invite,
]
expected_links = [
(bob_join, create),
+ (power, create),
(power, bob_join),
+ (alice_invite, create),
(alice_invite, power),
+ (alice_invite, bob_join),
(bob_join_2, power),
(alice_join2, power_2),
- (charlie_invite, alice_join2),
]
- # We either persist as a batch or one-by-one depending on test
- # parameter.
- if batched:
- self.persist(events)
- else:
- for event in events:
- self.persist([event])
-
+ self.persist(events)
chain_map, link_map = self.fetch_chains(events)
# Check that the expected links and only the expected links have been
# added.
- event_map = {e.event_id: e for e in events}
- reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
- self.maxDiff = None
- self.assertCountEqual(
- expected_links,
- [
- (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
- for s1, s2, t1, t2 in link_map.get_additions()
- ],
- )
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
# Test that everything can reach the create event, but the create event
# can't reach anything.
@@ -400,23 +367,24 @@ class EventChainStoreTestCase(HomeserverTestCase):
expected_links = [
(bob_join, create),
+ (power, create),
(power, bob_join),
+ (alice_invite, create),
(alice_invite, power),
+ (alice_invite, bob_join),
]
# Check that the expected links and only the expected links have been
# added.
- event_map = {e.event_id: e for e in events}
- reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
- self.maxDiff = None
- self.assertCountEqual(
- expected_links,
- [
- (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
- for s1, s2, t1, t2 in link_map.get_additions()
- ],
- )
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
def persist(
self,
@@ -431,7 +399,6 @@ class EventChainStoreTestCase(HomeserverTestCase):
for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering
- e.internal_metadata.instance_name = self.hs.get_instance_name()
self._next_stream_ordering += 1
def _persist(txn: LoggingTransaction) -> None:
@@ -447,14 +414,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
# Actually call the function that calculates the auth chain stuff.
- new_event_links = (
- persist_events_store.calculate_chain_cover_index_for_events_txn(
- txn, events[0].room_id, [e for e in events if e.is_state()]
- )
- )
- persist_events_store._persist_event_auth_chain_txn(
- txn, events, new_event_links
- )
+ persist_events_store._persist_event_auth_chain_txn(txn, events)
self.get_success(
persist_events_store.db_pool.runInteraction(
@@ -528,6 +488,8 @@ class LinkMapTestCase(unittest.TestCase):
link_map = _LinkMap()
link_map.add_link((1, 1), (2, 1), new=False)
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+ self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
self.assertCountEqual(link_map.get_additions(), [])
self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
@@ -536,30 +498,17 @@ class LinkMapTestCase(unittest.TestCase):
# Attempting to add a redundant link is ignored.
self.assertFalse(link_map.add_link((1, 4), (2, 1)))
- self.assertCountEqual(link_map.get_additions(), [])
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
# Adding new non-redundant links works
self.assertTrue(link_map.add_link((1, 3), (2, 3)))
- self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)])
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertTrue(link_map.add_link((2, 5), (1, 3)))
- self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+ self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
- def test_exists_path_from(self) -> None:
- "Check that `exists_path_from` can handle non-direct links"
- link_map = _LinkMap()
-
- link_map.add_link((1, 1), (2, 1), new=False)
- link_map.add_link((2, 1), (3, 1), new=False)
-
- self.assertTrue(link_map.exists_path_from((1, 4), (3, 1)))
- self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
-
- link_map.add_link((1, 5), (2, 3), new=False)
- link_map.add_link((2, 2), (3, 3), new=False)
-
- self.assertTrue(link_map.exists_path_from((1, 6), (3, 2)))
- self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
+ self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 088f0d24f9..0a6253e22c 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -365,19 +365,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- events = [
- cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
- for event_id in AUTH_GRAPH
- ]
- new_event_links = (
- self.persist_events.calculate_chain_cover_index_for_events_txn(
- txn, room_id, [e for e in events if e.is_state()]
- )
- )
self.persist_events._persist_event_auth_chain_txn(
txn,
- events,
- new_event_links,
+ [
+ cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+ for event_id in AUTH_GRAPH
+ ],
)
self.get_success(
@@ -551,9 +544,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
rooms.
"""
- # We allow partial covers for this test
- self.hs.get_datastores().main.tests_allow_no_chain_cover_index = True
-
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -638,20 +628,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
- events = [
- cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
- for event_id in auth_graph
- if event_id != "b"
- ]
- new_event_links = (
- self.persist_events.calculate_chain_cover_index_for_events_txn(
- txn, room_id, [e for e in events if e.is_state()]
- )
- )
self.persist_events._persist_event_auth_chain_txn(
txn,
- events,
- new_event_links,
+ [
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
+ for event_id in auth_graph
+ if event_id != "b"
+ ],
)
# Now we insert the event 'B' without a chain cover, by temporarily
@@ -664,14 +647,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
- events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
- new_event_links = (
- self.persist_events.calculate_chain_cover_index_for_events_txn(
- txn, room_id, [e for e in events if e.is_state()]
- )
- )
self.persist_events._persist_event_auth_chain_txn(
- txn, events, new_event_links
+ txn,
+ [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
self.store.db_pool.simple_update_txn(
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 3f7ee86498..02f7e6cd39 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 233066bf82..9639d49d5f 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 0a7c4c9421..eb37700192 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 12b89cecb6..aef52f131e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -18,7 +17,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import Dict, List, Optional
+from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -28,55 +27,177 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.storage.util.sequence import (
- LocalSequenceGenerator,
- PostgresSequenceGenerator,
- SequenceGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
-class MultiWriterIdGeneratorBase(HomeserverTestCase):
- positive: bool = True
- tables: List[str] = ["foobar"]
-
+class StreamIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
- self.instances: Dict[str, MultiWriterIdGenerator] = {}
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
- if USE_POSTGRES_FOR_TESTS:
- self.seq_gen: SequenceGenerator = PostgresSequenceGenerator("foobar_seq")
- else:
- self.seq_gen = LocalSequenceGenerator(lambda _: 0)
-
def _setup_db(self, txn: LoggingTransaction) -> None:
- if USE_POSTGRES_FOR_TESTS:
- txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+ txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
- for table in self.tables:
- txn.execute(
- """
- CREATE TABLE %s (
- stream_id BIGINT NOT NULL,
- instance_name TEXT NOT NULL,
- data TEXT
- );
- """
- % (table,)
+ def _create_id_generator(self) -> StreamIdGenerator:
+ def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
+ return StreamIdGenerator(
+ db_conn=conn,
+ notifier=self.hs.get_replication_notifier(),
+ table="foobar",
+ column="stream_id",
)
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+ def test_initial_value(self) -> None:
+ """Check that we read the current token from the DB."""
+ id_gen = self._create_id_generator()
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ def test_single_gen_next(self) -> None:
+ """Check that we correctly increment the current token from the DB."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ async with id_gen.get_next() as next_id:
+ # We haven't persisted `next_id` yet; current token is still 123
+ self.assertEqual(id_gen.get_current_token(), 123)
+ # But we did learn what the next value is
+ self.assertEqual(next_id, 124)
+
+ # Once the context manager closes we assume that the `next_id` has been
+ # written to the DB.
+ self.assertEqual(id_gen.get_current_token(), 124)
+
+ self.get_success(test_gen_next())
+
+ def test_multiple_gen_nexts(self) -> None:
+ """Check that we handle overlapping calls to gen_next sensibly."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request three new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ # None are persisted: current token unchanged.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Persist each in turn.
+ await ctx1.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 124)
+ await ctx2.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 125)
+ await ctx3.__aexit__(None, None, None)
+ self.assertEqual(id_gen.get_current_token(), 126)
+
+ self.get_success(test_gen_next())
+
+ def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
+ """Check that we handle overlapping calls to gen_next, even when their IDs
+ created and persisted in different orders."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request three new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ # None are persisted: current token unchanged.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Persist them in a different order, starting with 126 from ctx3.
+ await ctx3.__aexit__(None, None, None)
+ # We haven't persisted 124 from ctx1 yet---current token is still 123.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Now persist 124 from ctx1.
+ await ctx1.__aexit__(None, None, None)
+ # Current token is then 124, waiting for 125 to be persisted.
+ self.assertEqual(id_gen.get_current_token(), 124)
+
+ # Finally persist 125 from ctx2.
+ await ctx2.__aexit__(None, None, None)
+ # Current token is then 126 (skipping over 125).
+ self.assertEqual(id_gen.get_current_token(), 126)
+
+ self.get_success(test_gen_next())
+
+ def test_gen_next_while_still_waiting_for_persistence(self) -> None:
+ """Check that we handle overlapping calls to gen_next."""
+ id_gen = self._create_id_generator()
+
+ async def test_gen_next() -> None:
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
+ ctx3 = id_gen.get_next()
+
+ # Request two new stream IDs.
+ self.assertEqual(await ctx1.__aenter__(), 124)
+ self.assertEqual(await ctx2.__aenter__(), 125)
+
+ # Persist ctx2 first.
+ await ctx2.__aexit__(None, None, None)
+ # Still waiting on ctx1's ID to be persisted.
+ self.assertEqual(id_gen.get_current_token(), 123)
+
+ # Now request a third stream ID. It should be 126 (the smallest ID that
+ # we've not yet handed out.)
+ self.assertEqual(await ctx3.__aenter__(), 126)
+
+ self.get_success(test_gen_next())
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn: LoggingTransaction) -> None:
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
def _create_id_generator(
- self,
- instance_name: str = "master",
- writers: Optional[List[str]] = None,
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
@@ -85,98 +206,58 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
- tables=[(table, "instance_name", "stream_id") for table in self.tables],
+ tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers or ["master"],
- positive=self.positive,
)
- self.instances[instance_name] = self.get_success_or_raise(
- self.db_pool.runWithConnection(_create)
- )
- return self.instances[instance_name]
-
- def _replicate(self, instance_name: str) -> None:
- """Similate a replication event for the given instance."""
-
- writer = self.instances[instance_name]
- token = writer.get_current_token_for_writer(instance_name)
- for generator in self.instances.values():
- if writer != generator:
- generator.advance(instance_name, token)
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
- def _replicate_all(self) -> None:
- """Similate a replication event for all instances."""
-
- for instance_name in self.instances:
- self._replicate(instance_name)
+ def _insert_rows(self, instance_name: str, number: int) -> None:
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
- def _insert_row(
- self, instance_name: str, stream_id: int, table: Optional[str] = None
- ) -> None:
- """Insert one row as the given instance with given stream_id."""
+ def _insert(txn: LoggingTransaction) -> None:
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
- if table is None:
- table = self.tables[0]
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
- factor = 1 if self.positive else -1
+ def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
+ """Insert one row as the given instance with given stream_id, updating
+ the postgres sequence position to match.
+ """
def _insert(txn: LoggingTransaction) -> None:
txn.execute(
- "INSERT INTO %s VALUES (?, ?)" % (table,),
+ "INSERT INTO foobar VALUES (?, ?)",
(
stream_id,
instance_name,
),
)
+ txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
""",
- (instance_name, stream_id * factor, stream_id * factor),
+ (instance_name, stream_id, stream_id),
)
- self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
-
- def _insert_rows(
- self,
- instance_name: str,
- number: int,
- table: Optional[str] = None,
- update_stream_table: bool = True,
- ) -> None:
- """Insert N rows as the given instance, inserting with stream IDs pulled
- from the postgres sequence.
- """
-
- if table is None:
- table = self.tables[0]
-
- factor = 1 if self.positive else -1
-
- def _insert(txn: LoggingTransaction) -> None:
- for _ in range(number):
- next_val = self.seq_gen.get_next_id_txn(txn)
- txn.execute(
- "INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
- % (table,),
- (next_val, instance_name),
- )
-
- if update_stream_table:
- txn.execute(
- """
- INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
- ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
- """,
- (instance_name, next_val * factor, next_val * factor),
- )
-
- self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
-
+ self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
-class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
def test_empty(self) -> None:
"""Test an ID generator against an empty database gives sensible
current positions.
@@ -265,106 +346,137 @@ class MultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
- def test_get_next_txn(self) -> None:
- """Test that the `get_next_txn` function works correctly."""
+ def test_multi_instance(self) -> None:
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
- # Prefill table with 7 rows written by 'master'
- self._insert_rows("master", 7)
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- id_gen = self._create_id_generator()
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
- self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
- def _get_next_txn(txn: LoggingTransaction) -> None:
- stream_id = id_gen.get_next_txn(txn)
- self.assertEqual(stream_id, 8)
+ async def _get_next_async() -> None:
+ async with first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
- self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+ self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
- self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
+ self.get_success(_get_next_async())
- self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
- def test_restart_during_out_of_order_persistence(self) -> None:
- """Test that restarting a process while another process is writing out
- of order updates are handled correctly.
- """
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- # Prefill table with 7 rows written by 'master'
- self._insert_rows("master", 7)
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
- id_gen = self._create_id_generator()
+ async def _get_next_async2() -> None:
+ async with second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
- self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
- # Persist two rows at once
- ctx1 = id_gen.get_next()
- ctx2 = id_gen.get_next()
+ self.get_success(_get_next_async2())
- s1 = self.get_success(ctx1.__aenter__())
- s2 = self.get_success(ctx2.__aenter__())
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
- self.assertEqual(s1, 8)
- self.assertEqual(s2, 9)
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
- self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+ def test_multi_instance_empty_row(self) -> None:
+ """Test that reads and writes from multiple processes are handled
+ correctly, when one of the writers starts without any rows.
+ """
+ # Insert some rows for two out of three of the ID gens.
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
- # We finish persisting the second row before restart
- self.get_success(ctx2.__aexit__(None, None, None))
+ first_id_gen = self._create_id_generator(
+ "first", writers=["first", "second", "third"]
+ )
+ second_id_gen = self._create_id_generator(
+ "second", writers=["first", "second", "third"]
+ )
+ third_id_gen = self._create_id_generator(
+ "third", writers=["first", "second", "third"]
+ )
- # We simulate a restart of another worker by just creating a new ID gen.
- id_gen_worker = self._create_id_generator("worker")
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
- # Restarted worker should not see the second persisted row
- self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
- self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
- # Now if we persist the first row then both instances should jump ahead
- # correctly.
- self.get_success(ctx1.__aexit__(None, None, None))
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
- self.assertEqual(id_gen.get_positions(), {"master": 9})
- id_gen_worker.advance("master", 9)
- self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+ async def _get_next_async() -> None:
+ async with third_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
-class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
+ self.get_success(_get_next_async())
- def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
- """Insert one row as the given instance with given stream_id, updating
- the postgres sequence position to match.
- """
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+ )
- def _insert(txn: LoggingTransaction) -> None:
- txn.execute(
- "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
- (
- stream_id,
- instance_name,
- ),
- )
+ def test_get_next_txn(self) -> None:
+ """Test that the `get_next_txn` function works correctly."""
- txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
- txn.execute(
- """
- INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
- ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
- """,
- (instance_name, stream_id, stream_id),
- )
+ id_gen = self._create_id_generator()
- self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ def _get_next_txn(txn: LoggingTransaction) -> None:
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_get_persisted_upto_position(self) -> None:
"""Test that `get_persisted_upto_position` correctly tracks updates to
@@ -418,9 +530,7 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
id_gen = self._create_id_generator("first", writers=["first", "second"])
- # When the writer is created, it assumes its own position is the current head of
- # the sequence
- self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -437,118 +547,49 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
- def test_multi_instance(self) -> None:
- """Test that reads and writes from multiple processes are handled
- correctly.
+ def test_restart_during_out_of_order_persistence(self) -> None:
+ """Test that restarting a process while another process is writing out
+ of order updates are handled correctly.
"""
- self._insert_rows("first", 3)
- first_id_gen = self._create_id_generator("first", writers=["first", "second"])
-
- self._insert_rows("second", 4)
- second_id_gen = self._create_id_generator("second", writers=["first", "second"])
-
- self._replicate_all()
- self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
- self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
-
- self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
- self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
-
- # Try allocating a new ID gen and check that we only see position
- # advanced after we leave the context manager.
-
- async def _get_next_async() -> None:
- async with first_id_gen.get_next() as stream_id:
- self.assertEqual(stream_id, 8)
-
- self.assertEqual(
- first_id_gen.get_positions(), {"first": 3, "second": 7}
- )
- self.assertEqual(
- second_id_gen.get_positions(), {"first": 3, "second": 7}
- )
- self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
-
- self.get_success(_get_next_async())
-
- self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
-
- # However the ID gen on the second instance won't have seen the update
- self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
-
- # ... but calling `get_next` on the second instance should give a unique
- # stream ID
-
- async def _get_next_async2() -> None:
- async with second_id_gen.get_next() as stream_id:
- self.assertEqual(stream_id, 9)
-
- self.assertEqual(
- second_id_gen.get_positions(), {"first": 3, "second": 7}
- )
-
- self.get_success(_get_next_async2())
-
- self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
- # If the second ID gen gets told about the first, it correctly updates
- second_id_gen.advance("first", 8)
- self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+ id_gen = self._create_id_generator()
- def test_multi_instance_empty_row(self) -> None:
- """Test that reads and writes from multiple processes are handled
- correctly, when one of the writers starts without any rows.
- """
- # Insert some rows for two out of three of the ID gens.
- self._insert_rows("first", 3)
- first_id_gen = self._create_id_generator(
- "first", writers=["first", "second", "third"]
- )
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- self._insert_rows("second", 4)
- second_id_gen = self._create_id_generator(
- "second", writers=["first", "second", "third"]
- )
- third_id_gen = self._create_id_generator(
- "third", writers=["first", "second", "third"]
- )
+ # Persist two rows at once
+ ctx1 = id_gen.get_next()
+ ctx2 = id_gen.get_next()
- self._replicate_all()
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
- self.assertEqual(
- first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
- )
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
- self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
- self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
- self.assertEqual(
- second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
- )
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
- self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
- self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- # Try allocating a new ID gen and check that we only see position
- # advanced after we leave the context manager.
+ # We finish persisting the second row before restart
+ self.get_success(ctx2.__aexit__(None, None, None))
- async def _get_next_async() -> None:
- async with third_id_gen.get_next() as stream_id:
- self.assertEqual(stream_id, 8)
+ # We simulate a restart of another worker by just creating a new ID gen.
+ id_gen_worker = self._create_id_generator("worker")
- self.assertEqual(
- third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
- )
- self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+ # Restarted worker should not see the second persisted row
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+ self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
- self.get_success(_get_next_async())
+ # Now if we persist the first row then both instances should jump ahead
+ # correctly.
+ self.get_success(ctx1.__aexit__(None, None, None))
- self.assertEqual(
- third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
- )
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ id_gen_worker.advance("master", 9)
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
def test_writer_config_change(self) -> None:
"""Test that changing the writer config correctly works."""
@@ -598,7 +639,7 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
def test_sequence_consistency(self) -> None:
- """Test that we correct the sequence if the table and sequence diverges."""
+ """Test that we error out if the table and sequence diverges."""
# Prefill with some rows
self._insert_row_with_id("master", 3)
@@ -609,23 +650,16 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.get_success(self.db_pool.runInteraction("_insert", _insert))
- # Creating the ID gen should now fix the inconsistency
- id_gen = self._create_id_generator()
-
- async def _get_next_async() -> None:
- async with id_gen.get_next() as stream_id:
- self.assertEqual(stream_id, 27)
-
- self.get_success(_get_next_async())
+ # Creating the ID gen should error
+ with self.assertRaises(IncorrectDatabaseSetup):
+ self._create_id_generator("first")
def test_minimal_local_token(self) -> None:
self._insert_rows("first", 3)
- first_id_gen = self._create_id_generator("first", writers=["first", "second"])
-
self._insert_rows("second", 4)
- second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- self._replicate_all()
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
@@ -638,17 +672,15 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
token when there are no writes.
"""
self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
-
- self._insert_rows("second", 4)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
- self._replicate_all()
-
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7)
@@ -687,13 +719,68 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(second_id_gen.get_current_token(), 7)
-class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- positive = False
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn: LoggingTransaction) -> None:
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
+ ) -> MultiWriterIdGenerator:
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
+ stream_name="test_stream",
+ instance_name=instance_name,
+ tables=[("foobar", "instance_name", "stream_id")],
+ sequence_name="foobar_seq",
+ writers=writers or ["master"],
+ positive=False,
+ )
+
+ return self.get_success(self.db_pool.runWithConnection(_create))
+
+ def _insert_row(self, instance_name: str, stream_id: int) -> None:
+ """Insert one row as the given instance with given stream_id."""
+
+ def _insert(txn: LoggingTransaction) -> None:
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)",
+ (
+ stream_id,
+ instance_name,
+ ),
+ )
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, -stream_id, -stream_id),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
@@ -739,7 +826,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id)
- self._replicate("first")
+ id_gen_2.advance("first", stream_id)
self.get_success(_get_next_async())
@@ -751,7 +838,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id)
- self._replicate("second")
+ id_gen_1.advance("second", stream_id)
self.get_success(_get_next_async2())
@@ -761,26 +848,98 @@ class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
-class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- tables = ["foobar1", "foobar2"]
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.db_pool: DatabasePool = self.store.db_pool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn: LoggingTransaction) -> None:
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar1 (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ txn.execute(
+ """
+ CREATE TABLE foobar2 (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(
+ self, instance_name: str = "master", writers: Optional[List[str]] = None
+ ) -> MultiWriterIdGenerator:
+ def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ notifier=self.hs.get_replication_notifier(),
+ stream_name="test_stream",
+ instance_name=instance_name,
+ tables=[
+ ("foobar1", "instance_name", "stream_id"),
+ ("foobar2", "instance_name", "stream_id"),
+ ],
+ sequence_name="foobar_seq",
+ writers=writers or ["master"],
+ )
+
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+ def _insert_rows(
+ self,
+ table: str,
+ instance_name: str,
+ number: int,
+ update_stream_table: bool = True,
+ ) -> None:
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
+ def _insert(txn: LoggingTransaction) -> None:
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+ (instance_name,),
+ )
+ if update_stream_table:
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
- self._insert_rows("first", 3, table="foobar1")
- first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ self._insert_rows("foobar1", "first", 3)
+ self._insert_rows("foobar2", "second", 3)
+ self._insert_rows("foobar2", "second", 1, update_stream_table=False)
- self._insert_rows("second", 3, table="foobar2")
- self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- self._replicate_all()
-
- self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7b5774b8c1..3b5b3c404f 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Awesome Technologies Innovationslabor GmbH
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9df8ea4ee6..e57566854e 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 0b984c7ebc..f3a6759405 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 91be2ccb3e..cb459d6b03 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 14e3871dc1..b4277b5436 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -43,6 +42,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.assertEqual(
UserInfo(
+ # TODO(paul): Surely this field should be 'user_id', not 'name'
user_id=UserID.from_string(self.user_id),
is_admin=False,
is_guest=False,
@@ -56,7 +56,6 @@ class RegistrationStoreTestCase(HomeserverTestCase):
locked=False,
is_shadow_banned=False,
approved=True,
- suspended=False,
),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py
index a7f7c840f3..5cc8dc263b 100644
--- a/tests/storage/test_relations.py
+++ b/tests/storage/test_relations.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 909aee043e..942692dd7b 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 34d6fdb71e..7759e4cf91 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 340642b7e7..1788ca2ab9 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -71,16 +70,17 @@ class EventSearchInsertionTest(HomeserverTestCase):
store.search_msgs([room_id], "hi bob", ["content.body"])
)
self.assertEqual(result.get("count"), 1)
- self.assertIn("hi", result.get("highlights"))
- self.assertIn("bob", result.get("highlights"))
+ if isinstance(store.database_engine, PostgresEngine):
+ self.assertIn("hi", result.get("highlights"))
+ self.assertIn("bob", result.get("highlights"))
# Check that search works for an unrelated message
result = self.get_success(
store.search_msgs([room_id], "another", ["content.body"])
)
self.assertEqual(result.get("count"), 1)
-
- self.assertIn("another", result.get("highlights"))
+ if isinstance(store.database_engine, PostgresEngine):
+ self.assertIn("another", result.get("highlights"))
# Check that search works for a search term that overlaps with the message
# containing a null byte and an unrelated message.
@@ -89,8 +89,8 @@ class EventSearchInsertionTest(HomeserverTestCase):
result = self.get_success(
store.search_msgs([room_id], "hi alice", ["content.body"])
)
-
- self.assertIn("alice", result.get("highlights"))
+ if isinstance(store.database_engine, PostgresEngine):
+ self.assertIn("alice", result.get("highlights"))
def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`.
@@ -327,11 +327,9 @@ class MessageSearchTest(HomeserverTestCase):
self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
- (
- f"expected '{query}' to match '{self.PHRASE}'"
- if expect_to_contain
- else f"'{query}' unexpectedly matched '{self.PHRASE}'"
- ),
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'",
)
self.assertEqual(
len(result["results"]),
@@ -347,11 +345,9 @@ class MessageSearchTest(HomeserverTestCase):
self.assertEqual(
result["count"],
1 if expect_to_contain else 0,
- (
- f"expected '{query}' to match '{self.PHRASE}'"
- if expect_to_contain
- else f"'{query}' unexpectedly matched '{self.PHRASE}'"
- ),
+ f"expected '{query}' to match '{self.PHRASE}'"
+ if expect_to_contain
+ else f"'{query}' unexpectedly matched '{self.PHRASE}'",
)
self.assertEqual(
len(result["results"]),
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 418b556108..31fcc8f829 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -19,28 +17,20 @@
# [This file includes modifications made by New Vector Limited]
#
#
-import logging
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.room_versions import RoomVersions
-from synapse.rest import admin
+from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
-from synapse.rest.client import knock, login, room
+from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
-from synapse.storage.roommember import MemberSummary
from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
from tests.server import TestHomeServer
from tests.test_utils import event_injection
-from tests.unittest import skip_unless
-
-logger = logging.getLogger(__name__)
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -248,397 +238,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
-class RoomSummaryTestCase(unittest.HomeserverTestCase):
- """
- Test `/sync` room summary related logic like `get_room_summary(...)` and
- `extract_heroes_from_room_summary(...)`
- """
-
- servlets = [
- admin.register_servlets,
- knock.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
-
- def _assert_member_summary(
- self,
- actual_member_summary: MemberSummary,
- expected_member_list: List[str],
- *,
- expected_member_count: Optional[int] = None,
- ) -> None:
- """
- Assert that the `MemberSummary` object has the expected members.
- """
- self.assertListEqual(
- [
- user_id
- for user_id, _membership_event_id in actual_member_summary.members
- ],
- expected_member_list,
- )
- self.assertEqual(
- actual_member_summary.count,
- (
- expected_member_count
- if expected_member_count is not None
- else len(expected_member_list)
- ),
- )
-
- def test_get_room_summary_membership(self) -> None:
- """
- Test that `get_room_summary(...)` gets every kind of membership when there
- aren't that many members in the room.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- _user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
- user5_id = self.register_user("user5", "pass")
- user5_tok = self.login(user5_id, "pass")
-
- # Setup a room (user1 is the creator and is joined to the room)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # User2 is banned
- self.helper.join(room_id, user2_id, tok=user2_tok)
- self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
-
- # User3 is invited by user1
- self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
-
- # User4 leaves
- self.helper.join(room_id, user4_id, tok=user4_tok)
- self.helper.leave(room_id, user4_id, tok=user4_tok)
-
- # User5 joins
- self.helper.join(room_id, user5_id, tok=user5_tok)
-
- room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
- empty_ms = MemberSummary([], 0)
-
- self._assert_member_summary(
- room_membership_summary.get(Membership.JOIN, empty_ms),
- [user1_id, user5_id],
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id]
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id]
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.BAN, empty_ms), [user2_id]
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.KNOCK, empty_ms),
- [
- # No one knocked
- ],
- )
-
- def test_get_room_summary_membership_order(self) -> None:
- """
- Test that `get_room_summary(...)` stacks our limit of 6 in this order: joins ->
- invites -> leave -> everything else (bans/knocks)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- _user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
- user5_id = self.register_user("user5", "pass")
- user5_tok = self.login(user5_id, "pass")
- user6_id = self.register_user("user6", "pass")
- user6_tok = self.login(user6_id, "pass")
- user7_id = self.register_user("user7", "pass")
- user7_tok = self.login(user7_id, "pass")
-
- # Setup the room (user1 is the creator and is joined to the room)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # We expect the order to be joins -> invites -> leave -> bans so setup the users
- # *NOT* in that same order to make sure we're actually sorting them.
-
- # User2 is banned
- self.helper.join(room_id, user2_id, tok=user2_tok)
- self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
-
- # User3 is invited by user1
- self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
-
- # User4 leaves
- self.helper.join(room_id, user4_id, tok=user4_tok)
- self.helper.leave(room_id, user4_id, tok=user4_tok)
-
- # User5, User6, User7 joins
- self.helper.join(room_id, user5_id, tok=user5_tok)
- self.helper.join(room_id, user6_id, tok=user6_tok)
- self.helper.join(room_id, user7_id, tok=user7_tok)
-
- room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
- empty_ms = MemberSummary([], 0)
-
- self._assert_member_summary(
- room_membership_summary.get(Membership.JOIN, empty_ms),
- [user1_id, user5_id, user6_id, user7_id],
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.INVITE, empty_ms), [user3_id]
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.LEAVE, empty_ms), [user4_id]
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.BAN, empty_ms),
- [
- # The banned user is not in the summary because the summary can only fit
- # 6 members and prefers everything else before bans
- #
- # user2_id
- ],
- # But we still see the count of banned users
- expected_member_count=1,
- )
- self._assert_member_summary(
- room_membership_summary.get(Membership.KNOCK, empty_ms),
- [
- # No one knocked
- ],
- )
-
- def test_extract_heroes_from_room_summary_excludes_self(self) -> None:
- """
- Test that `extract_heroes_from_room_summary(...)` does not include the user
- itself.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Setup the room (user1 is the creator and is joined to the room)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # User2 joins
- self.helper.join(room_id, user2_id, tok=user2_tok)
-
- room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
-
- # We first ask from the perspective of a random fake user
- hero_user_ids = extract_heroes_from_room_summary(
- room_membership_summary, me="@fakeuser"
- )
-
- # Make sure user1 is in the room (ensure our test setup is correct)
- self.assertListEqual(hero_user_ids, [user1_id, user2_id])
-
- # Now, we ask for the room summary from the perspective of user1
- hero_user_ids = extract_heroes_from_room_summary(
- room_membership_summary, me=user1_id
- )
-
- # User1 should not be included in the list of heroes because they are the one
- # asking
- self.assertListEqual(hero_user_ids, [user2_id])
-
- def test_extract_heroes_from_room_summary_first_five_joins(self) -> None:
- """
- Test that `extract_heroes_from_room_summary(...)` returns the first 5 joins.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
- user5_id = self.register_user("user5", "pass")
- user5_tok = self.login(user5_id, "pass")
- user6_id = self.register_user("user6", "pass")
- user6_tok = self.login(user6_id, "pass")
- user7_id = self.register_user("user7", "pass")
- user7_tok = self.login(user7_id, "pass")
-
- # Setup the room (user1 is the creator and is joined to the room)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # User2 -> User7 joins
- self.helper.join(room_id, user2_id, tok=user2_tok)
- self.helper.join(room_id, user3_id, tok=user3_tok)
- self.helper.join(room_id, user4_id, tok=user4_tok)
- self.helper.join(room_id, user5_id, tok=user5_tok)
- self.helper.join(room_id, user6_id, tok=user6_tok)
- self.helper.join(room_id, user7_id, tok=user7_tok)
-
- room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
-
- hero_user_ids = extract_heroes_from_room_summary(
- room_membership_summary, me="@fakuser"
- )
-
- # First 5 users to join the room
- self.assertListEqual(
- hero_user_ids, [user1_id, user2_id, user3_id, user4_id, user5_id]
- )
-
- def test_extract_heroes_from_room_summary_membership_order(self) -> None:
- """
- Test that `extract_heroes_from_room_summary(...)` prefers joins/invites over
- everything else.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- _user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- user4_tok = self.login(user4_id, "pass")
- user5_id = self.register_user("user5", "pass")
- user5_tok = self.login(user5_id, "pass")
-
- # Setup the room (user1 is the creator and is joined to the room)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # We expect the order to be joins -> invites -> leave -> bans so setup the users
- # *NOT* in that same order to make sure we're actually sorting them.
-
- # User2 is banned
- self.helper.join(room_id, user2_id, tok=user2_tok)
- self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
-
- # User3 is invited by user1
- self.helper.invite(room_id, targ=user3_id, tok=user1_tok)
-
- # User4 leaves
- self.helper.join(room_id, user4_id, tok=user4_tok)
- self.helper.leave(room_id, user4_id, tok=user4_tok)
-
- # User5 joins
- self.helper.join(room_id, user5_id, tok=user5_tok)
-
- room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
-
- hero_user_ids = extract_heroes_from_room_summary(
- room_membership_summary, me="@fakeuser"
- )
-
- # Prefer joins -> invites, over everything else
- self.assertListEqual(
- hero_user_ids,
- [
- # The joins
- user1_id,
- user5_id,
- # The invites
- user3_id,
- ],
- )
-
- @skip_unless(
- False,
- "Test is not possible because when everyone leaves the room, "
- + "the server is `no_longer_in_room` and we don't have any `current_state_events` to query",
- )
- def test_extract_heroes_from_room_summary_fallback_leave_ban(self) -> None:
- """
- Test that `extract_heroes_from_room_summary(...)` falls back to leave/ban if
- there aren't any joins/invites.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- user3_tok = self.login(user3_id, "pass")
-
- # Setup the room (user1 is the creator and is joined to the room)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # User2 is banned
- self.helper.join(room_id, user2_id, tok=user2_tok)
- self.helper.ban(room_id, src=user1_id, targ=user2_id, tok=user1_tok)
-
- # User3 leaves
- self.helper.join(room_id, user3_id, tok=user3_tok)
- self.helper.leave(room_id, user3_id, tok=user3_tok)
-
- # User1 leaves (we're doing this last because they're the room creator)
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- room_membership_summary = self.get_success(self.store.get_room_summary(room_id))
-
- hero_user_ids = extract_heroes_from_room_summary(
- room_membership_summary, me="@fakeuser"
- )
-
- # Fallback to people who left -> banned
- self.assertListEqual(
- hero_user_ids,
- [user3_id, user1_id, user3_id],
- )
-
- def test_extract_heroes_from_room_summary_excludes_knocks(self) -> None:
- """
- People who knock on the room have (potentially) never been in the room before
- and are total outsiders. Plus the spec doesn't mention them at all for heroes.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Setup the knock room (user1 is the creator and is joined to the room)
- knock_room_id = self.helper.create_room_as(
- user1_id, tok=user1_tok, room_version=RoomVersions.V7.identifier
- )
- self.helper.send_state(
- knock_room_id,
- EventTypes.JoinRules,
- {"join_rule": JoinRules.KNOCK},
- tok=user1_tok,
- )
-
- # User2 knocks on the room
- knock_channel = self.make_request(
- "POST",
- "/_matrix/client/r0/knock/%s" % (knock_room_id,),
- b"{}",
- user2_tok,
- )
- self.assertEqual(knock_channel.code, 200, knock_channel.result)
-
- room_membership_summary = self.get_success(
- self.store.get_room_summary(knock_room_id)
- )
-
- hero_user_ids = extract_heroes_from_room_summary(
- room_membership_summary, me="@fakeuser"
- )
-
- # user1 is the creator and is joined to the room (should show up as a hero)
- # user2 is knocking on the room (should not show up as a hero)
- self.assertListEqual(
- hero_user_ids,
- [user1_id],
- )
-
-
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 48f8d1c340..f42a74ac61 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 9dea1af8ea..ab861ead93 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -19,36 +18,19 @@
#
#
-import logging
-from typing import List, Tuple
-from unittest.mock import AsyncMock, patch
-
-from immutabledict import immutabledict
+from typing import List
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.filtering import Filter
-from synapse.crypto.event_signing import add_hashes_and_signatures
-from synapse.events import FrozenEventV3
-from synapse.federation.federation_client import SendJoinResult
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
-from synapse.types import (
- JsonDict,
- PersistedEventPosition,
- RoomStreamToken,
- UserID,
- create_requester,
-)
+from synapse.types import JsonDict
from synapse.util import Clock
-from tests.test_utils.event_injection import create_event
-from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
-
-logger = logging.getLogger(__name__)
+from tests.unittest import HomeserverTestCase
class PaginationTestCase(HomeserverTestCase):
@@ -285,1170 +267,3 @@ class PaginationTestCase(HomeserverTestCase):
}
chunk = self._filter_messages(filter)
self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none])
-
-
-class GetLastEventInRoomBeforeStreamOrderingTestCase(HomeserverTestCase):
- """
- Test `get_last_event_pos_in_room_before_stream_ordering(...)`
- """
-
- servlets = [
- admin.register_servlets,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
-
- def _update_persisted_instance_name_for_event(
- self, event_id: str, instance_name: str
- ) -> None:
- """
- Update the `instance_name` that persisted the the event in the database.
- """
- return self.get_success(
- self.store.db_pool.simple_update_one(
- "events",
- keyvalues={"event_id": event_id},
- updatevalues={"instance_name": instance_name},
- )
- )
-
- def _send_event_on_instance(
- self, instance_name: str, room_id: str, access_token: str
- ) -> Tuple[JsonDict, PersistedEventPosition]:
- """
- Send an event in a room and mimic that it was persisted by a specific
- instance/worker.
- """
- event_response = self.helper.send(
- room_id, f"{instance_name} message", tok=access_token
- )
-
- self._update_persisted_instance_name_for_event(
- event_response["event_id"], instance_name
- )
-
- event_pos = self.get_success(
- self.store.get_position_for_event(event_response["event_id"])
- )
-
- return event_response, event_pos
-
- def test_before_room_created(self) -> None:
- """
- Test that no event is returned if we are using a token before the room was even created
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- before_room_token = self.event_sources.get_current_token()
-
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
-
- last_event_result = self.get_success(
- self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id=room_id,
- end_token=before_room_token.room_key,
- )
- )
-
- self.assertIsNone(last_event_result)
-
- def test_after_room_created(self) -> None:
- """
- Test that an event is returned if we are using a token after the room was created
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
-
- after_room_token = self.event_sources.get_current_token()
-
- last_event_result = self.get_success(
- self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id=room_id,
- end_token=after_room_token.room_key,
- )
- )
- assert last_event_result is not None
- last_event_id, _ = last_event_result
-
- self.assertIsNotNone(last_event_id)
-
- def test_activity_in_other_rooms(self) -> None:
- """
- Test to make sure that the last event in the room is returned even if the
- `stream_ordering` has advanced from activity in other rooms.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- event_response = self.helper.send(room_id1, "target!", tok=user1_tok)
- # Create another room to advance the stream_ordering
- self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
-
- after_room_token = self.event_sources.get_current_token()
-
- last_event_result = self.get_success(
- self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id=room_id1,
- end_token=after_room_token.room_key,
- )
- )
- assert last_event_result is not None
- last_event_id, _ = last_event_result
-
- # Make sure it's the event we expect (which also means we know it's from the
- # correct room)
- self.assertEqual(last_event_id, event_response["event_id"])
-
- def test_activity_after_token_has_no_effect(self) -> None:
- """
- Test to make sure we return the last event before the token even if there is
- activity after it.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- event_response = self.helper.send(room_id1, "target!", tok=user1_tok)
-
- after_room_token = self.event_sources.get_current_token()
-
- # Send some events after the token
- self.helper.send(room_id1, "after1", tok=user1_tok)
- self.helper.send(room_id1, "after2", tok=user1_tok)
-
- last_event_result = self.get_success(
- self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id=room_id1,
- end_token=after_room_token.room_key,
- )
- )
- assert last_event_result is not None
- last_event_id, _ = last_event_result
-
- # Make sure it's the last event before the token
- self.assertEqual(last_event_id, event_response["event_id"])
-
- def test_last_event_within_sharded_token(self) -> None:
- """
- Test to make sure we can find the last event that that is *within* the sharded
- token (a token that has an `instance_map` and looks like
- `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`). We are specifically testing
- that we can find an event within the tokens minimum and instance
- `stream_ordering`.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- event_response1, event_pos1 = self._send_event_on_instance(
- "worker1", room_id1, user1_tok
- )
- event_response2, event_pos2 = self._send_event_on_instance(
- "worker1", room_id1, user1_tok
- )
- event_response3, event_pos3 = self._send_event_on_instance(
- "worker1", room_id1, user1_tok
- )
-
- # Create another room to advance the `stream_ordering` on the same worker
- # so we can sandwich event3 in the middle of the token
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- event_response4, event_pos4 = self._send_event_on_instance(
- "worker1", room_id2, user1_tok
- )
-
- # Assemble a token that encompasses event1 -> event4 on worker1
- end_token = RoomStreamToken(
- stream=event_pos2.stream,
- instance_map=immutabledict({"worker1": event_pos4.stream}),
- )
-
- # Send some events after the token
- self.helper.send(room_id1, "after1", tok=user1_tok)
- self.helper.send(room_id1, "after2", tok=user1_tok)
-
- last_event_result = self.get_success(
- self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id=room_id1,
- end_token=end_token,
- )
- )
- assert last_event_result is not None
- last_event_id, _ = last_event_result
-
- # Should find closest event before the token in room1
- self.assertEqual(
- last_event_id,
- event_response3["event_id"],
- f"We expected {event_response3['event_id']} but saw {last_event_id} which corresponds to "
- + str(
- {
- "event1": event_response1["event_id"],
- "event2": event_response2["event_id"],
- "event3": event_response3["event_id"],
- }
- ),
- )
-
- def test_last_event_before_sharded_token(self) -> None:
- """
- Test to make sure we can find the last event that is *before* the sharded token
- (a token that has an `instance_map` and looks like
- `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- event_response1, event_pos1 = self._send_event_on_instance(
- "worker1", room_id1, user1_tok
- )
- event_response2, event_pos2 = self._send_event_on_instance(
- "worker1", room_id1, user1_tok
- )
-
- # Create another room to advance the `stream_ordering` on the same worker
- room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- event_response3, event_pos3 = self._send_event_on_instance(
- "worker1", room_id2, user1_tok
- )
- event_response4, event_pos4 = self._send_event_on_instance(
- "worker1", room_id2, user1_tok
- )
-
- # Assemble a token that encompasses event3 -> event4 on worker1
- end_token = RoomStreamToken(
- stream=event_pos3.stream,
- instance_map=immutabledict({"worker1": event_pos4.stream}),
- )
-
- # Send some events after the token
- self.helper.send(room_id1, "after1", tok=user1_tok)
- self.helper.send(room_id1, "after2", tok=user1_tok)
-
- last_event_result = self.get_success(
- self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id=room_id1,
- end_token=end_token,
- )
- )
- assert last_event_result is not None
- last_event_id, _ = last_event_result
-
- # Should find closest event before the token in room1
- self.assertEqual(
- last_event_id,
- event_response2["event_id"],
- f"We expected {event_response2['event_id']} but saw {last_event_id} which corresponds to "
- + str(
- {
- "event1": event_response1["event_id"],
- "event2": event_response2["event_id"],
- }
- ),
- )
-
- def test_restrict_event_types(self) -> None:
- """
- Test that we only consider given `event_types` when finding the last event
- before a token.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
- event_response = self.helper.send_event(
- room_id1,
- type="org.matrix.special_message",
- content={"body": "before1, target!"},
- tok=user1_tok,
- )
- self.helper.send(room_id1, "before2", tok=user1_tok)
-
- after_room_token = self.event_sources.get_current_token()
-
- # Send some events after the token
- self.helper.send_event(
- room_id1,
- type="org.matrix.special_message",
- content={"body": "after1"},
- tok=user1_tok,
- )
- self.helper.send(room_id1, "after2", tok=user1_tok)
-
- last_event_result = self.get_success(
- self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id=room_id1,
- end_token=after_room_token.room_key,
- event_types=["org.matrix.special_message"],
- )
- )
- assert last_event_result is not None
- last_event_id, _ = last_event_result
-
- # Make sure it's the last event before the token
- self.assertEqual(last_event_id, event_response["event_id"])
-
-
-class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
- """
- Test `get_current_state_delta_membership_changes_for_user(...)`
- """
-
- servlets = [
- admin.register_servlets,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
- self.state_handler = self.hs.get_state_handler()
- persistence = hs.get_storage_controllers().persistence
- assert persistence is not None
- self.persistence = persistence
-
- def test_returns_membership_events(self) -> None:
- """
- A basic test that a membership event in the token range is returned for the user.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
- join_pos = self.get_success(
- self.store.get_position_for_event(join_response["event_id"])
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_room1_token.room_key,
- to_key=after_room1_token.room_key,
- )
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=join_response["event_id"],
- event_pos=join_pos,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- )
- ],
- )
-
- def test_server_left_room_after_us(self) -> None:
- """
- Test that when probing over part of the DAG where the server left the room *after
- us*, we still see the join and leave changes.
-
- This is to make sure we play nicely with this behavior: When the server leaves a
- room, it will insert new rows with `event_id = null` into the
- `current_state_delta_stream` table for all current state.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "power_level_content_override": {
- "users": {
- user2_id: 100,
- # Allow user1 to send state in the room
- user1_id: 100,
- }
- }
- },
- )
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- join_pos1 = self.get_success(
- self.store.get_position_for_event(join_response1["event_id"])
- )
- # Make sure that random other non-member state that happens to have a `state_key`
- # matching the user ID doesn't mess with things.
- self.helper.send_state(
- room_id1,
- event_type="foobarbazdummy",
- state_key=user1_id,
- body={"foo": "bar"},
- tok=user1_tok,
- )
- # User1 should leave the room first
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- leave_pos1 = self.get_success(
- self.store.get_position_for_event(leave_response1["event_id"])
- )
-
- # User2 should also leave the room (everyone has left the room which means the
- # server is no longer in the room).
- self.helper.leave(room_id1, user2_id, tok=user2_tok)
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Get the membership changes for the user.
- #
- # At this point, the `current_state_delta_stream` table should look like the
- # following. When the server leaves a room, it will insert new rows with
- # `event_id = null` for all current state.
- #
- # | stream_id | room_id | type | state_key | event_id | prev_event_id |
- # |-----------|----------|-----------------------------|----------------|----------|---------------|
- # | 2 | !x:test | 'm.room.create' | '' | $xxx | None |
- # | 3 | !x:test | 'm.room.member' | '@user2:test' | $aaa | None |
- # | 4 | !x:test | 'm.room.history_visibility' | '' | $xxx | None |
- # | 4 | !x:test | 'm.room.join_rules' | '' | $xxx | None |
- # | 4 | !x:test | 'm.room.power_levels' | '' | $xxx | None |
- # | 7 | !x:test | 'm.room.member' | '@user1:test' | $ooo | None |
- # | 8 | !x:test | 'foobarbazdummy' | '@user1:test' | $xxx | None |
- # | 9 | !x:test | 'm.room.member' | '@user1:test' | $ppp | $ooo |
- # | 10 | !x:test | 'foobarbazdummy' | '@user1:test' | None | $xxx |
- # | 10 | !x:test | 'm.room.create' | '' | None | $xxx |
- # | 10 | !x:test | 'm.room.history_visibility' | '' | None | $xxx |
- # | 10 | !x:test | 'm.room.join_rules' | '' | None | $xxx |
- # | 10 | !x:test | 'm.room.member' | '@user1:test' | None | $ppp |
- # | 10 | !x:test | 'm.room.member' | '@user2:test' | None | $aaa |
- # | 10 | !x:test | 'm.room.power_levels' | | None | $xxx |
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_room1_token.room_key,
- to_key=after_room1_token.room_key,
- )
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=join_response1["event_id"],
- event_pos=join_pos1,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- ),
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=leave_response1["event_id"],
- event_pos=leave_pos1,
- membership="leave",
- sender=user1_id,
- prev_event_id=join_response1["event_id"],
- prev_event_pos=join_pos1,
- prev_membership="join",
- prev_sender=user1_id,
- ),
- ],
- )
-
- def test_server_left_room_after_us_later(self) -> None:
- """
- Test when the user leaves the room, then sometime later, everyone else leaves
- the room, causing the server to leave the room, we shouldn't see any membership
- changes.
-
- This is to make sure we play nicely with this behavior: When the server leaves a
- room, it will insert new rows with `event_id = null` into the
- `current_state_delta_stream` table for all current state.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id1, user1_id, tok=user1_tok)
- # User1 should leave the room first
- self.helper.leave(room_id1, user1_id, tok=user1_tok)
-
- after_user1_leave_token = self.event_sources.get_current_token()
-
- # User2 should also leave the room (everyone has left the room which means the
- # server is no longer in the room).
- self.helper.leave(room_id1, user2_id, tok=user2_tok)
-
- after_server_leave_token = self.event_sources.get_current_token()
-
- # Join another room as user1 just to advance the stream_ordering and bust
- # `_membership_stream_cache`
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id2, user1_id, tok=user1_tok)
-
- # Get the membership changes for the user.
- #
- # At this point, the `current_state_delta_stream` table should look like the
- # following. When the server leaves a room, it will insert new rows with
- # `event_id = null` for all current state.
- #
- # TODO: Add DB rows to better see what's going on.
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=after_user1_leave_token.room_key,
- to_key=after_server_leave_token.room_key,
- )
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [],
- )
-
- def test_we_cause_server_left_room(self) -> None:
- """
- Test that when probing over part of the DAG where the user leaves the room
- causing the server to leave the room (because we were the last local user in the
- room), we still see the join and leave changes.
-
- This is to make sure we play nicely with this behavior: When the server leaves a
- room, it will insert new rows with `event_id = null` into the
- `current_state_delta_stream` table for all current state.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- room_id1 = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "power_level_content_override": {
- "users": {
- user2_id: 100,
- # Allow user1 to send state in the room
- user1_id: 100,
- }
- }
- },
- )
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- join_pos1 = self.get_success(
- self.store.get_position_for_event(join_response1["event_id"])
- )
- # Make sure that random other non-member state that happens to have a `state_key`
- # matching the user ID doesn't mess with things.
- self.helper.send_state(
- room_id1,
- event_type="foobarbazdummy",
- state_key=user1_id,
- body={"foo": "bar"},
- tok=user1_tok,
- )
-
- # User2 should leave the room first.
- self.helper.leave(room_id1, user2_id, tok=user2_tok)
-
- # User1 (the person we're testing with) should also leave the room (everyone has
- # left the room which means the server is no longer in the room).
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- leave_pos1 = self.get_success(
- self.store.get_position_for_event(leave_response1["event_id"])
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Get the membership changes for the user.
- #
- # At this point, the `current_state_delta_stream` table should look like the
- # following. When the server leaves a room, it will insert new rows with
- # `event_id = null` for all current state.
- #
- # | stream_id | room_id | type | state_key | event_id | prev_event_id |
- # |-----------|-----------|-----------------------------|---------------|----------|---------------|
- # | 2 | '!x:test' | 'm.room.create' | '' | '$xxx' | None |
- # | 3 | '!x:test' | 'm.room.member' | '@user2:test' | '$aaa' | None |
- # | 4 | '!x:test' | 'm.room.history_visibility' | '' | '$xxx' | None |
- # | 4 | '!x:test' | 'm.room.join_rules' | '' | '$xxx' | None |
- # | 4 | '!x:test' | 'm.room.power_levels' | '' | '$xxx' | None |
- # | 7 | '!x:test' | 'm.room.member' | '@user1:test' | '$ooo' | None |
- # | 8 | '!x:test' | 'foobarbazdummy' | '@user1:test' | '$xxx' | None |
- # | 9 | '!x:test' | 'm.room.member' | '@user2:test' | '$bbb' | '$aaa' |
- # | 10 | '!x:test' | 'foobarbazdummy' | '@user1:test' | None | '$xxx' |
- # | 10 | '!x:test' | 'm.room.create' | '' | None | '$xxx' |
- # | 10 | '!x:test' | 'm.room.history_visibility' | '' | None | '$xxx' |
- # | 10 | '!x:test' | 'm.room.join_rules' | '' | None | '$xxx' |
- # | 10 | '!x:test' | 'm.room.member' | '@user1:test' | None | '$ooo' |
- # | 10 | '!x:test' | 'm.room.member' | '@user2:test' | None | '$bbb' |
- # | 10 | '!x:test' | 'm.room.power_levels' | '' | None | '$xxx' |
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_room1_token.room_key,
- to_key=after_room1_token.room_key,
- )
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=join_response1["event_id"],
- event_pos=join_pos1,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- ),
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=None, # leave_response1["event_id"],
- event_pos=leave_pos1,
- membership="leave",
- sender=None, # user1_id,
- prev_event_id=join_response1["event_id"],
- prev_event_pos=join_pos1,
- prev_membership="join",
- prev_sender=user1_id,
- ),
- ],
- )
-
- def test_different_user_membership_persisted_in_same_batch(self) -> None:
- """
- Test batch of membership events from different users being processed at once.
- This will result in all of the memberships being stored in the
- `current_state_delta_stream` table with the same `stream_ordering` even though
- the individual events have different `stream_ordering`s.
- """
- user1_id = self.register_user("user1", "pass")
- _user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
- user3_id = self.register_user("user3", "pass")
- _user3_tok = self.login(user3_id, "pass")
- user4_id = self.register_user("user4", "pass")
- _user4_tok = self.login(user4_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- # User2 is just the designated person to create the room (we do this across the
- # tests to be consistent)
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
-
- # Persist the user1, user3, and user4 join events in the same batch so they all
- # end up in the `current_state_delta_stream` table with the same
- # stream_ordering.
- join_event3, join_event_context3 = self.get_success(
- create_event(
- self.hs,
- sender=user3_id,
- type=EventTypes.Member,
- state_key=user3_id,
- content={"membership": "join"},
- room_id=room_id1,
- )
- )
- # We want to put user1 in the middle of the batch. This way, regardless of the
- # implementation that inserts rows into current_state_delta_stream` (whether it
- # be minimum/maximum of stream position of the batch), we will still catch bugs.
- join_event1, join_event_context1 = self.get_success(
- create_event(
- self.hs,
- sender=user1_id,
- type=EventTypes.Member,
- state_key=user1_id,
- content={"membership": "join"},
- room_id=room_id1,
- )
- )
- join_event4, join_event_context4 = self.get_success(
- create_event(
- self.hs,
- sender=user4_id,
- type=EventTypes.Member,
- state_key=user4_id,
- content={"membership": "join"},
- room_id=room_id1,
- )
- )
- self.get_success(
- self.persistence.persist_events(
- [
- (join_event3, join_event_context3),
- (join_event1, join_event_context1),
- (join_event4, join_event_context4),
- ]
- )
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- # Get the membership changes for the user.
- #
- # At this point, the `current_state_delta_stream` table should look like (notice
- # those three memberships at the end with `stream_id=7` because we persisted
- # them in the same batch):
- #
- # | stream_id | room_id | type | state_key | event_id | prev_event_id |
- # |-----------|-----------|----------------------------|------------------|----------|---------------|
- # | 2 | '!x:test' | 'm.room.create' | '' | '$xxx' | None |
- # | 3 | '!x:test' | 'm.room.member' | '@user2:test' | '$xxx' | None |
- # | 4 | '!x:test' | 'm.room.history_visibility'| '' | '$xxx' | None |
- # | 4 | '!x:test' | 'm.room.join_rules' | '' | '$xxx' | None |
- # | 4 | '!x:test' | 'm.room.power_levels' | '' | '$xxx' | None |
- # | 7 | '!x:test' | 'm.room.member' | '@user3:test' | '$xxx' | None |
- # | 7 | '!x:test' | 'm.room.member' | '@user1:test' | '$xxx' | None |
- # | 7 | '!x:test' | 'm.room.member' | '@user4:test' | '$xxx' | None |
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_room1_token.room_key,
- to_key=after_room1_token.room_key,
- )
- )
-
- join_pos3 = self.get_success(
- self.store.get_position_for_event(join_event3.event_id)
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=join_event1.event_id,
- # Ideally, this would be `join_pos1` (to match the `event_id`) but
- # when events are persisted in a batch, they are all stored in the
- # `current_state_delta_stream` table with the minimum
- # `stream_ordering` from the batch.
- event_pos=join_pos3,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- ),
- ],
- )
-
- def test_state_reset(self) -> None:
- """
- Test a state reset scenario where the user gets removed from the room (when
- there is no corresponding leave event)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- join_pos1 = self.get_success(
- self.store.get_position_for_event(join_response1["event_id"])
- )
-
- before_reset_token = self.event_sources.get_current_token()
-
- # Send another state event to make a position for the state reset to happen at
- dummy_state_response = self.helper.send_state(
- room_id1,
- event_type="foobarbaz",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- dummy_state_pos = self.get_success(
- self.store.get_position_for_event(dummy_state_response["event_id"])
- )
-
- # Mock a state reset removing the membership for user1 in the current state
- self.get_success(
- self.store.db_pool.simple_delete(
- table="current_state_events",
- keyvalues={
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- },
- desc="state reset user in current_state_delta_stream",
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- table="current_state_delta_stream",
- values={
- "stream_id": dummy_state_pos.stream,
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- "event_id": None,
- "prev_event_id": join_response1["event_id"],
- "instance_name": dummy_state_pos.instance_name,
- },
- desc="state reset user in current_state_delta_stream",
- )
- )
-
- # Manually bust the cache since we we're just manually messing with the database
- # and not causing an actual state reset.
- self.store._membership_stream_cache.entity_has_changed(
- user1_id, dummy_state_pos.stream
- )
-
- after_reset_token = self.event_sources.get_current_token()
-
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_reset_token.room_key,
- to_key=after_reset_token.room_key,
- )
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=None,
- event_pos=dummy_state_pos,
- membership="leave",
- sender=None, # user1_id,
- prev_event_id=join_response1["event_id"],
- prev_event_pos=join_pos1,
- prev_membership="join",
- prev_sender=user1_id,
- ),
- ],
- )
-
- def test_excluded_room_ids(self) -> None:
- """
- Test that the `excluded_room_ids` option excludes changes from the specified rooms.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_room1_token = self.event_sources.get_current_token()
-
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- join_pos1 = self.get_success(
- self.store.get_position_for_event(join_response1["event_id"])
- )
-
- room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
- join_pos2 = self.get_success(
- self.store.get_position_for_event(join_response2["event_id"])
- )
-
- after_room1_token = self.event_sources.get_current_token()
-
- # First test the the room is returned without the `excluded_room_ids` option
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_room1_token.room_key,
- to_key=after_room1_token.room_key,
- )
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=join_response1["event_id"],
- event_pos=join_pos1,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- ),
- CurrentStateDeltaMembership(
- room_id=room_id2,
- event_id=join_response2["event_id"],
- event_pos=join_pos2,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- ),
- ],
- )
-
- # The test that `excluded_room_ids` excludes room2 as expected
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_room1_token.room_key,
- to_key=after_room1_token.room_key,
- excluded_room_ids=[room_id2],
- )
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=room_id1,
- event_id=join_response1["event_id"],
- event_pos=join_pos1,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- )
- ],
- )
-
-
-class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase(
- FederatingHomeserverTestCase
-):
- """
- Test `get_current_state_delta_membership_changes_for_user(...)` when joining remote federated rooms.
- """
-
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
- self.room_member_handler = hs.get_room_member_handler()
-
- def test_remote_join(self) -> None:
- """
- Test remote join where the first rows in `current_state_delta_stream` will just
- be the state when you joined the remote room.
- """
- user1_id = self.register_user("user1", "pass")
- _user1_tok = self.login(user1_id, "pass")
-
- before_join_token = self.event_sources.get_current_token()
-
- intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
-
- # Remotely join a room on another homeserver.
- #
- # To do this we have to mock the responses from the remote homeserver. We also
- # patch out a bunch of event checks on our end.
- create_event_source = {
- "auth_events": [],
- "content": {
- "creator": f"@creator:{self.OTHER_SERVER_NAME}",
- "room_version": self.hs.config.server.default_room_version.identifier,
- },
- "depth": 0,
- "origin_server_ts": 0,
- "prev_events": [],
- "room_id": intially_unjoined_room_id,
- "sender": f"@creator:{self.OTHER_SERVER_NAME}",
- "state_key": "",
- "type": EventTypes.Create,
- }
- self.add_hashes_and_signatures_from_other_server(
- create_event_source,
- self.hs.config.server.default_room_version,
- )
- create_event = FrozenEventV3(
- create_event_source,
- self.hs.config.server.default_room_version,
- {},
- None,
- )
- creator_join_event_source = {
- "auth_events": [create_event.event_id],
- "content": {
- "membership": "join",
- },
- "depth": 1,
- "origin_server_ts": 1,
- "prev_events": [],
- "room_id": intially_unjoined_room_id,
- "sender": f"@creator:{self.OTHER_SERVER_NAME}",
- "state_key": f"@creator:{self.OTHER_SERVER_NAME}",
- "type": EventTypes.Member,
- }
- self.add_hashes_and_signatures_from_other_server(
- creator_join_event_source,
- self.hs.config.server.default_room_version,
- )
- creator_join_event = FrozenEventV3(
- creator_join_event_source,
- self.hs.config.server.default_room_version,
- {},
- None,
- )
-
- # Our local user is going to remote join the room
- join_event_source = {
- "auth_events": [create_event.event_id],
- "content": {"membership": "join"},
- "depth": 1,
- "origin_server_ts": 100,
- "prev_events": [creator_join_event.event_id],
- "sender": user1_id,
- "state_key": user1_id,
- "room_id": intially_unjoined_room_id,
- "type": EventTypes.Member,
- }
- add_hashes_and_signatures(
- self.hs.config.server.default_room_version,
- join_event_source,
- self.hs.hostname,
- self.hs.signing_key,
- )
- join_event = FrozenEventV3(
- join_event_source,
- self.hs.config.server.default_room_version,
- {},
- None,
- )
-
- mock_make_membership_event = AsyncMock(
- return_value=(
- self.OTHER_SERVER_NAME,
- join_event,
- self.hs.config.server.default_room_version,
- )
- )
- mock_send_join = AsyncMock(
- return_value=SendJoinResult(
- join_event,
- self.OTHER_SERVER_NAME,
- state=[create_event, creator_join_event],
- auth_chain=[create_event, creator_join_event],
- partial_state=False,
- servers_in_room=frozenset(),
- )
- )
-
- with patch.object(
- self.room_member_handler.federation_handler.federation_client,
- "make_membership_event",
- mock_make_membership_event,
- ), patch.object(
- self.room_member_handler.federation_handler.federation_client,
- "send_join",
- mock_send_join,
- ), patch(
- "synapse.event_auth._is_membership_change_allowed",
- return_value=None,
- ), patch(
- "synapse.handlers.federation_event.check_state_dependent_auth_rules",
- return_value=None,
- ):
- self.get_success(
- self.room_member_handler.update_membership(
- requester=create_requester(user1_id),
- target=UserID.from_string(user1_id),
- room_id=intially_unjoined_room_id,
- action=Membership.JOIN,
- remote_room_hosts=[self.OTHER_SERVER_NAME],
- )
- )
-
- after_join_token = self.event_sources.get_current_token()
-
- # Get the membership changes for the user.
- #
- # At this point, the `current_state_delta_stream` table should look like the
- # following. Notice that all of the events are at the same `stream_id` because
- # the current state starts out where we remotely joined:
- #
- # | stream_id | room_id | type | state_key | event_id | prev_event_id |
- # |-----------|------------------------------|-----------------|------------------------------|----------|---------------|
- # | 2 | '!example:other.example.com' | 'm.room.member' | '@user1:test' | '$xxx' | None |
- # | 2 | '!example:other.example.com' | 'm.room.create' | '' | '$xxx' | None |
- # | 2 | '!example:other.example.com' | 'm.room.member' | '@creator:other.example.com' | '$xxx' | None |
- membership_changes = self.get_success(
- self.store.get_current_state_delta_membership_changes_for_user(
- user1_id,
- from_key=before_join_token.room_key,
- to_key=after_join_token.room_key,
- )
- )
-
- join_pos = self.get_success(
- self.store.get_position_for_event(join_event.event_id)
- )
-
- # Let the whole diff show on failure
- self.maxDiff = None
- self.assertEqual(
- membership_changes,
- [
- CurrentStateDeltaMembership(
- room_id=intially_unjoined_room_id,
- event_id=join_event.event_id,
- event_pos=join_pos,
- membership="join",
- sender=user1_id,
- prev_event_id=None,
- prev_event_pos=None,
- prev_membership=None,
- prev_sender=None,
- ),
- ],
- )
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 5d58521810..4e99355fda 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index 4f652fc179..9912f1fa28 100644
--- a/tests/storage/test_unsafe_locale.py
+++ b/tests/storage/test_unsafe_locale.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index c26932069f..b4289b803a 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -711,10 +710,6 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
),
)
- self.assertEqual(_parse_words_with_icu("user-1"), ["user-1"])
- self.assertEqual(_parse_words_with_icu("user-ab"), ["user-ab"])
- self.assertEqual(_parse_words_with_icu("user.--1"), ["user", "-1"])
-
def test_regex_word_boundary_punctuation(self) -> None:
"""
Tests the behaviour of punctuation with the non-ICU tokeniser
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index 177da340e5..a39f44bc26 100644
--- a/tests/storage/test_user_filters.py
+++ b/tests/storage/test_user_filters.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/util/__init__.py b/tests/storage/util/__init__.py
index dab387a504..3d833a2e44 100644
--- a/tests/storage/util/__init__.py
+++ b/tests/storage/util/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 1e5663f137..a9e6d1f5c5 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 18792fdee3..55da82854a 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 6d1ae4c8d7..972284e55b 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 4e9adc0625..debba61b42 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 16206d5a97..b2b9cedf8f 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_server.py b/tests/test_server.py
index 9ff2589497..0910ea5f28 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -392,7 +392,8 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
)
self.assertEqual(channel.code, 301)
- location_headers = channel.headers.getRawHeaders(b"Location", [])
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/look/an/eagle"])
def test_redirect_exception_with_cookie(self) -> None:
@@ -414,10 +415,10 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
)
self.assertEqual(channel.code, 304)
- headers = channel.headers
- location_headers = headers.getRawHeaders(b"Location", [])
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/no/over/there"])
- cookies_headers = headers.getRawHeaders(b"Set-Cookie", [])
+ cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
def test_head_request(self) -> None:
diff --git a/tests/test_state.py b/tests/test_state.py
index 311a590693..ac3f6b100d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index c52f963a7e..21b44510bd 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_types.py b/tests/test_types.py
index 00adc65a5a..0d9cb17286 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -19,18 +18,9 @@
#
#
-from typing import Type
-from unittest import skipUnless
-
-from immutabledict import immutabledict
-from parameterized import parameterized_class
-
from synapse.api.errors import SynapseError
from synapse.types import (
- AbstractMultiWriterStreamToken,
- MultiWriterStreamToken,
RoomAlias,
- RoomStreamToken,
UserID,
get_domain_from_id,
get_localpart_from_id,
@@ -38,7 +28,6 @@ from synapse.types import (
)
from tests import unittest
-from tests.utils import USE_POSTGRES_FOR_TESTS
class IsMineIDTests(unittest.HomeserverTestCase):
@@ -137,64 +126,3 @@ class MapUsernameTestCase(unittest.TestCase):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
-
-
-@parameterized_class(
- ("token_type",),
- [
- (MultiWriterStreamToken,),
- (RoomStreamToken,),
- ],
- class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
-)
-class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
- """Tests for the different types of multi writer tokens."""
-
- token_type: Type[AbstractMultiWriterStreamToken]
-
- def test_basic_token(self) -> None:
- """Test that a simple stream token can be serialized and unserialized"""
- store = self.hs.get_datastores().main
-
- token = self.token_type(stream=5)
-
- string_token = self.get_success(token.to_string(store))
-
- if isinstance(token, RoomStreamToken):
- self.assertEqual(string_token, "s5")
- else:
- self.assertEqual(string_token, "5")
-
- parsed_token = self.get_success(self.token_type.parse(store, string_token))
- self.assertEqual(parsed_token, token)
-
- @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
- def test_instance_map(self) -> None:
- """Test for stream token with instance map"""
- store = self.hs.get_datastores().main
-
- token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
-
- string_token = self.get_success(token.to_string(store))
- self.assertEqual(string_token, "m5~1.6")
-
- parsed_token = self.get_success(self.token_type.parse(store, string_token))
- self.assertEqual(parsed_token, token)
-
- def test_instance_map_assertion(self) -> None:
- """Test that we assert values in the instance map are greater than the
- min stream position"""
-
- with self.assertRaises(ValueError):
- self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
-
- with self.assertRaises(ValueError):
- self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
-
- def test_parse_bad_token(self) -> None:
- """Test that we can parse tokens produced by a bug in Synapse of the
- form `m5~`"""
- store = self.hs.get_datastores().main
-
- parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
- self.assertEqual(parsed_token, self.token_type(stream=5))
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 4ab42a02b9..0908e80c14 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 35b3245708..a82d25eaf2 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -125,15 +124,13 @@ async def mark_event_as_partial_state(
in this table).
"""
store = hs.get_datastores().main
- # Use the store helper to insert into the database so the caches are busted
- await store.store_partial_state_room(
- room_id=room_id,
- servers={hs.hostname},
- device_lists_stream_id=0,
- joined_via=hs.hostname,
+ await store.db_pool.simple_upsert(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={"room_id": room_id},
)
- # FIXME: Bust the cache
await store.db_pool.simple_insert(
table="partial_state_events",
values={
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index a0f39cb130..7ae1423de6 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 6c4be1c1f8..9b204b196b 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 89cbe4e54b..e51f72d65f 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,19 +21,13 @@ import logging
from typing import Optional
from unittest.mock import patch
-from synapse.api.constants import EventUnsignedContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
-from synapse.rest import admin
-from synapse.rest.client import login, room
-from synapse.server import HomeServer
-from synapse.types import create_requester
+from synapse.types import JsonDict, create_requester
from synapse.visibility import filter_events_for_client, filter_events_for_server
from tests import unittest
-from tests.test_utils.event_injection import inject_event, inject_member_event
-from tests.unittest import HomeserverTestCase
from tests.utils import create_room
logger = logging.getLogger(__name__)
@@ -62,31 +56,15 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
#
# before we do that, we persist some other events to act as state.
- self.get_success(
- inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined")
- )
+ self._inject_visibility("@admin:hs", "joined")
for i in range(10):
- self.get_success(
- inject_member_event(
- self.hs,
- TEST_ROOM_ID,
- "@resident%i:hs" % i,
- "join",
- )
- )
+ self._inject_room_member("@resident%i:hs" % i)
events_to_filter = []
for i in range(10):
- evt = self.get_success(
- inject_member_event(
- self.hs,
- TEST_ROOM_ID,
- "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"),
- "join",
- extra_content={"a": "b"},
- )
- )
+ user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
+ evt = self._inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
filtered = self.get_success(
@@ -112,19 +90,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def test_filter_outlier(self) -> None:
# outlier events must be returned, for the good of the collective federation
- self.get_success(
- inject_member_event(
- self.hs,
- TEST_ROOM_ID,
- "@resident:remote_hs",
- "join",
- )
- )
- self.get_success(
- inject_visibility_event(
- self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined"
- )
- )
+ self._inject_room_member("@resident:remote_hs")
+ self._inject_visibility("@resident:remote_hs", "joined")
outlier = self._inject_outlier()
self.assertEqual(
@@ -143,9 +110,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
)
# it should also work when there are other events in the list
- evt = self.get_success(
- inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
- )
+ evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success(
filter_events_for_server(
@@ -185,34 +150,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# change in the middle of them.
events_to_filter = []
- evt = self.get_success(
- inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
- )
+ evt = self._inject_message("@unerased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(
- inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
- )
+ evt = self._inject_message("@erased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(
- inject_member_event(
- self.hs,
- TEST_ROOM_ID,
- "@joiner:remote_hs",
- "join",
- )
- )
+ evt = self._inject_room_member("@joiner:remote_hs")
events_to_filter.append(evt)
- evt = self.get_success(
- inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
- )
+ evt = self._inject_message("@unerased:local_hs")
events_to_filter.append(evt)
- evt = self.get_success(
- inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
- )
+ evt = self._inject_message("@erased:local_hs")
events_to_filter.append(evt)
# the erasey user gets erased
@@ -250,142 +200,99 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
for i in (1, 4):
self.assertNotIn("body", filtered[i].content)
- def _inject_outlier(self) -> EventBase:
+ def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
+ content = {"history_visibility": visibility}
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
- "type": "m.room.member",
- "sender": "@test:user",
- "state_key": "@test:user",
+ "type": "m.room.history_visibility",
+ "sender": user_id,
+ "state_key": "",
"room_id": TEST_ROOM_ID,
- "content": {"membership": "join"},
+ "content": content,
},
)
- event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
- event.internal_metadata.outlier = True
- self.get_success(
- self._persistence.persist_event(
- event, EventContext.for_outlier(self._storage_controllers)
- )
+ event, unpersisted_context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+ self.get_success(self._persistence.persist_event(event, context))
return event
+ def _inject_room_member(
+ self,
+ user_id: str,
+ membership: str = "join",
+ extra_content: Optional[JsonDict] = None,
+ ) -> EventBase:
+ content = {"membership": membership}
+ content.update(extra_content or {})
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ },
+ )
-class FilterEventsForClientTestCase(HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def test_joined_history_visibility(self) -> None:
- # User joins and leaves room. Should be able to see the join and leave,
- # and messages sent between the two, but not before or after.
+ event, unpersisted_context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+ context = self.get_success(unpersisted_context.persist(event))
- self.register_user("resident", "p1")
- resident_token = self.login("resident", "p1")
- room_id = self.helper.create_room_as("resident", tok=resident_token)
+ self.get_success(self._persistence.persist_event(event, context))
+ return event
- self.get_success(
- inject_visibility_event(self.hs, room_id, "@resident:test", "joined")
- )
- before_event = self.get_success(
- inject_message_event(self.hs, room_id, "@resident:test", body="before")
- )
- join_event = self.get_success(
- inject_member_event(self.hs, room_id, "@joiner:test", "join")
- )
- during_event = self.get_success(
- inject_message_event(self.hs, room_id, "@resident:test", body="during")
- )
- leave_event = self.get_success(
- inject_member_event(self.hs, room_id, "@joiner:test", "leave")
- )
- after_event = self.get_success(
- inject_message_event(self.hs, room_id, "@resident:test", body="after")
+ def _inject_message(
+ self, user_id: str, content: Optional[JsonDict] = None
+ ) -> EventBase:
+ if content is None:
+ content = {"body": "testytest", "msgtype": "m.text"}
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": "m.room.message",
+ "sender": user_id,
+ "room_id": TEST_ROOM_ID,
+ "content": content,
+ },
)
- # We have to reload the events from the db, to ensure that prev_content is
- # populated.
- events_to_filter = [
- self.get_success(
- self.hs.get_storage_controllers().main.get_event(
- e.event_id,
- get_prev_content=True,
- )
- )
- for e in [
- before_event,
- join_event,
- during_event,
- leave_event,
- after_event,
- ]
- ]
-
- # Now run the events through the filter, and check that we can see the events
- # we expect, and that the membership prop is as expected.
- #
- # We deliberately do the queries for both users upfront; this simulates
- # concurrent queries on the server, and helps ensure that we aren't
- # accidentally serving the same event object (with the same unsigned.membership
- # property) to both users.
- joiner_filtered_events = self.get_success(
- filter_events_for_client(
- self.hs.get_storage_controllers(),
- "@joiner:test",
- events_to_filter,
- )
- )
- resident_filtered_events = self.get_success(
- filter_events_for_client(
- self.hs.get_storage_controllers(),
- "@resident:test",
- events_to_filter,
- )
+ event, unpersisted_context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
- # The joiner should be able to seem the join and leave,
- # and messages sent between the two, but not before or after.
- self.assertEqual(
- [e.event_id for e in [join_event, during_event, leave_event]],
- [e.event_id for e in joiner_filtered_events],
- )
- self.assertEqual(
- ["join", "join", "leave"],
- [
- e.unsigned[EventUnsignedContentFields.MEMBERSHIP]
- for e in joiner_filtered_events
- ],
- )
+ self.get_success(self._persistence.persist_event(event, context))
+ return event
- # The resident user should see all the events.
- self.assertEqual(
- [
- e.event_id
- for e in [
- before_event,
- join_event,
- during_event,
- leave_event,
- after_event,
- ]
- ],
- [e.event_id for e in resident_filtered_events],
+ def _inject_outlier(self) -> EventBase:
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": "m.room.member",
+ "sender": "@test:user",
+ "state_key": "@test:user",
+ "room_id": TEST_ROOM_ID,
+ "content": {"membership": "join"},
+ },
)
- self.assertEqual(
- ["join", "join", "join", "join", "join"],
- [
- e.unsigned[EventUnsignedContentFields.MEMBERSHIP]
- for e in resident_filtered_events
- ],
+
+ event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
+ event.internal_metadata.outlier = True
+ self.get_success(
+ self._persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
+ )
)
+ return event
-class FilterEventsOutOfBandEventsForClientTestCase(
- unittest.FederatingHomeserverTestCase
-):
+class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then
# rejected it.
@@ -434,23 +341,15 @@ class FilterEventsOutOfBandEventsForClientTestCase(
)
# the invited user should be able to see both the invite and the rejection
- filtered_events = self.get_success(
- filter_events_for_client(
- self.hs.get_storage_controllers(),
- "@user:test",
- [invite_event, reject_event],
- )
- )
- self.assertEqual(
- [e.event_id for e in filtered_events],
- [e.event_id for e in [invite_event, reject_event]],
- )
self.assertEqual(
- ["invite", "leave"],
- [
- e.unsigned[EventUnsignedContentFields.MEMBERSHIP]
- for e in filtered_events
- ],
+ self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
+ )
+ ),
+ [invite_event, reject_event],
)
# other users should see neither
@@ -464,34 +363,3 @@ class FilterEventsOutOfBandEventsForClientTestCase(
),
[],
)
-
-
-async def inject_visibility_event(
- hs: HomeServer,
- room_id: str,
- sender: str,
- visibility: str,
-) -> EventBase:
- return await inject_event(
- hs,
- type="m.room.history_visibility",
- sender=sender,
- state_key="",
- room_id=room_id,
- content={"history_visibility": visibility},
- )
-
-
-async def inject_message_event(
- hs: HomeServer,
- room_id: str,
- sender: str,
- body: Optional[str] = "testytest",
-) -> EventBase:
- return await inject_event(
- hs,
- type="m.room.message",
- sender=sender,
- room_id=room_id,
- content={"body": body, "msgtype": "m.text"},
- )
diff --git a/tests/unittest.py b/tests/unittest.py
index 4aa7f56106..c2e120ffa6 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,8 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 Matrix.org Federation C.I.C
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -28,7 +26,6 @@ import logging
import secrets
import time
from typing import (
- AbstractSet,
Any,
Awaitable,
Callable,
@@ -110,7 +107,8 @@ class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""
@property
- def value(self) -> _ExcType: ...
+ def value(self) -> _ExcType:
+ ...
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
@@ -270,56 +268,6 @@ class TestCase(unittest.TestCase):
required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
)
- def assertIncludes(
- self,
- actual_items: AbstractSet[str],
- expected_items: AbstractSet[str],
- exact: bool = False,
- message: Optional[str] = None,
- ) -> None:
- """
- Assert that all of the `expected_items` are included in the `actual_items`.
-
- This assert could also be called `assertContains`, `assertItemsInSet`
-
- Args:
- actual_items: The container
- expected_items: The items to check for in the container
- exact: Whether the actual state should be exactly equal to the expected
- state (no extras).
- message: Optional message to include in the failure message.
- """
- # Check that each set has the same items
- if exact and actual_items == expected_items:
- return
- # Check for a superset
- elif not exact and actual_items >= expected_items:
- return
-
- expected_lines: List[str] = []
- for expected_item in expected_items:
- is_expected_in_actual = expected_item in actual_items
- expected_lines.append(
- "{} {}".format(" " if is_expected_in_actual else "?", expected_item)
- )
-
- actual_lines: List[str] = []
- for actual_item in actual_items:
- is_actual_in_expected = actual_item in expected_items
- actual_lines.append(
- "{} {}".format("+" if is_actual_in_expected else " ", actual_item)
- )
-
- newline = "\n"
- expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}"
- actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}"
- first_message = (
- "Items must match exactly" if exact else "Some expected items are missing."
- )
- diff_message = f"{first_message}\n{expected_string}\n{actual_string}"
-
- self.fail(f"{diff_message}\n{message}")
-
def DEBUG(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.DEBUG.
@@ -395,8 +343,6 @@ class HomeserverTestCase(TestCase):
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)
- self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
-
# Honour the `use_frozen_dicts` config option. We have to do this
# manually because this is taken care of in the app `start` code, which
# we don't run. Plus we want to reset it on tearDown.
@@ -576,7 +522,6 @@ class HomeserverTestCase(TestCase):
request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
- content_type: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[Iterable[CustomHeaderType]] = None,
@@ -595,9 +540,6 @@ class HomeserverTestCase(TestCase):
with the usual REST API path, if it doesn't contain it.
federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
-
- content_type: The content-type to use for the request. If not set then will default to
- application/json unless content_is_form is true.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -623,7 +565,6 @@ class HomeserverTestCase(TestCase):
request,
shorthand,
federation_auth_origin,
- content_type,
content_is_form,
await_result,
custom_headers,
@@ -690,13 +631,13 @@ class HomeserverTestCase(TestCase):
return self.successResultOf(deferred)
def get_failure(
- self, d: Awaitable[Any], exc: Type[_ExcType], by: float = 0.0
+ self, d: Awaitable[Any], exc: Type[_ExcType]
) -> _TypedFailure[_ExcType]:
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
- self.pump(by)
+ self.pump()
return self.failureResultOf(deferred, exc)
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
diff --git a/tests/util/__init__.py b/tests/util/__init__.py
index fcd2134c89..3d833a2e44 100644
--- a/tests/util/__init__.py
+++ b/tests/util/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/caches/__init__.py b/tests/util/caches/__init__.py
index abc9f6d57d..3d833a2e44 100644
--- a/tests/util/caches/__init__.py
+++ b/tests/util/caches/__init__.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2017 Vector Creations Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py
index d6b3d47422..f9e183b65c 100644
--- a/tests/util/caches/test_cached_call.py
+++ b/tests/util/caches/test_cached_call.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index f99f99237e..df1b4f587f 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 6af9dfaf56..d403014aad 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py
index e350967bba..57e7b39b69 100644
--- a/tests/util/caches/test_response_cache.py
+++ b/tests/util/caches/test_response_cache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index 2dcf3a3412..897d0b6b31 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 13a4e6ddaa..65e0793d64 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -21,7 +20,6 @@
from contextlib import contextmanager
from os import PathLike
-from pathlib import Path
from typing import Generator, Optional, Union
from unittest.mock import patch
@@ -42,7 +40,7 @@ class DummyDistribution(metadata.Distribution):
def version(self) -> str:
return self._version
- def locate_file(self, path: Union[str, PathLike]) -> Path:
+ def locate_file(self, path: Union[str, PathLike]) -> PathLike:
raise NotImplementedError()
def read_text(self, filename: str) -> None:
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 5055e4aead..c598a6db06 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index e97e5cf77d..afbf5a926a 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2017 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 7a593cc683..5ef691cf03 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 7cbb1007da..c44db6bf8e 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -34,7 +33,8 @@ from tests import unittest
class UnblockFunction(Protocol):
- def __call__(self, pump_reactor: bool = True) -> None: ...
+ def __call__(self, pump_reactor: bool = True) -> None:
+ ...
class LinearizerTestCase(unittest.TestCase):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index f7c5f5faca..40380f6a3e 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 3f0d8139f8..2283a1dbba 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -383,34 +382,3 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
# the items should still be in the cache
self.assertEqual(cache.get("key1"), 1)
self.assertEqual(cache.get("key2"), 2)
-
-
-class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
- def test_invalidate_simple(self) -> None:
- cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
- cache["key1"] = 1
- cache["key2"] = 2
-
- cache.invalidate_on_extra_index("key1")
- self.assertEqual(cache.get("key1"), 1)
- self.assertEqual(cache.get("key2"), 2)
-
- cache.invalidate_on_extra_index("1")
- self.assertEqual(cache.get("key1"), None)
- self.assertEqual(cache.get("key2"), 2)
-
- def test_invalidate_multi(self) -> None:
- cache: LruCache[str, int] = LruCache(10, extra_index_cb=lambda k, v: str(v))
- cache["key1"] = 1
- cache["key2"] = 1
- cache["key3"] = 2
-
- cache.invalidate_on_extra_index("key1")
- self.assertEqual(cache.get("key1"), 1)
- self.assertEqual(cache.get("key2"), 1)
- self.assertEqual(cache.get("key3"), 2)
-
- cache.invalidate_on_extra_index("1")
- self.assertEqual(cache.get("key1"), None)
- self.assertEqual(cache.get("key2"), None)
- self.assertEqual(cache.get("key3"), 2)
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index e0507667a2..fa3fa3ff3c 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 7bb45f9bf2..2aeba9ab33 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 2c286c19a2..98296e611c 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 12f821d684..37f97a622d 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index af1199ef8a..3df053493b 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,5 +1,3 @@
-from parameterized import parameterized
-
from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests import unittest
@@ -163,8 +161,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3))
- @parameterized.expand([(0,), (1000000000,)])
- def test_get_entities_changed(self, perf_factor: int) -> None:
+ def test_get_entities_changed(self) -> None:
"""
StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the
@@ -181,9 +178,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# get the ones after that point.
self.assertEqual(
cache.get_entities_changed(
- ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
- stream_pos=2,
- _perf_factor=perf_factor,
+ ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
),
{"bar@baz.net", "user@elsewhere.org"},
)
@@ -200,7 +195,6 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"not@here.website",
],
stream_pos=2,
- _perf_factor=perf_factor,
),
{"bar@baz.net", "user@elsewhere.org"},
)
@@ -216,7 +210,6 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"not@here.website",
],
stream_pos=0,
- _perf_factor=perf_factor,
),
{"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
)
@@ -224,11 +217,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Query a subset of the entries mid-way through the stream. We should
# only get back the subset.
self.assertEqual(
- cache.get_entities_changed(
- ["bar@baz.net"],
- stream_pos=2,
- _perf_factor=perf_factor,
- ),
+ cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
{"bar@baz.net"},
)
@@ -249,5 +238,5 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
- # Unknown entities will return None
- self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None)
+ # Unknown entities will return the stream start position.
+ self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 646fd2163e..b022cea0e8 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
index 30f0510c9f..ae2e9a334c 100644
--- a/tests/util/test_task_scheduler.py
+++ b/tests/util/test_task_scheduler.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
index 15575cc572..279798f1e5 100644
--- a/tests/util/test_threepids.py
+++ b/tests/util/test_threepids.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2020 Dirk Klimpel
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 1bc5a7e267..7290958897 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 173a7cfaec..b75a32b73f 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
diff --git a/tests/utils.py b/tests/utils.py
index 9fd26ef348..5798b04eef 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,7 +1,6 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
@@ -21,20 +20,7 @@
import atexit
import os
-import signal
-from types import FrameType, TracebackType
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- Optional,
- Tuple,
- Type,
- TypeVar,
- Union,
- overload,
-)
+from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
import attr
from typing_extensions import Literal, ParamSpec
@@ -134,11 +120,13 @@ def setupdb() -> None:
@overload
-def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ...
+def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
+ ...
@overload
-def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ...
+def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
+ ...
def default_config(
@@ -392,30 +380,3 @@ def checked_cast(type: Type[T], x: object) -> T:
"""
assert isinstance(x, type)
return x
-
-
-class TestTimeout(Exception):
- pass
-
-
-class test_timeout:
- def __init__(self, seconds: int, error_message: Optional[str] = None) -> None:
- if error_message is None:
- error_message = "test timed out after {}s.".format(seconds)
- self.seconds = seconds
- self.error_message = error_message
-
- def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None:
- raise TestTimeout(self.error_message)
-
- def __enter__(self) -> None:
- signal.signal(signal.SIGALRM, self.handle_timeout)
- signal.alarm(self.seconds)
-
- def __exit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- exc_tb: Optional[TracebackType],
- ) -> None:
- signal.alarm(0)
|