diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 5d6cc2885f..024c5e963c 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -26,7 +26,7 @@ from .. import unittest
class AppServiceHandlerTestCase(unittest.TestCase):
- """ Tests the ApplicationServicesHandler. """
+ """Tests the ApplicationServicesHandler."""
def setUp(self):
self.mock_store = Mock()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 70ebaf2c78..b9812b67e4 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -27,7 +27,7 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
- """ Tests the directory service. """
+ """Tests the directory service."""
def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 61a00130b8..e0a24824cc 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
devices = self.get_success(
- self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
+ self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user, "device123"
+ )
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = self.get_success(
- self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
+ self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user, "device123"
+ )
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# fetch the signed keys/devices and make sure that the signatures are there
ret = self.get_success(
self.handler.query_devices(
- {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ {"device_keys": {local_user: [], other_user: []}},
+ 0,
+ local_user,
+ "device123",
)
)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index d90a9fec91..dfb9b3a0fa 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -863,7 +863,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(builder.build(prev_event_ids, None))
+ event = self.get_success(
+ builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+ )
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 8834d21f0d..e8f9294118 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -23,7 +23,7 @@ from tests.test_utils import make_awaitable
class ProfileTestCase(unittest.HomeserverTestCase):
- """ Tests profile management. """
+ """Tests profile management."""
def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index aa2999f174..a0147421f1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,6 +17,7 @@ from unittest.mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
+from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.rest.client.v2_alpha.register import (
_map_email_to_displayname,
register_servlets,
@@ -32,8 +33,93 @@ from tests.utils import mock_getRawHeaders
from .. import unittest
+class TestSpamChecker:
+ def __init__(self, config, api):
+ api.register_spam_checker_callbacks(
+ check_registration_for_spam=self.check_registration_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ pass
+
+
+class DenyAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.DENY
+
+
+class BanAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+
+class BanBadIdPUser(TestSpamChecker):
+ async def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+
+class TestLegacyRegistrationSpamChecker:
+ def __init__(self, config, api):
+ pass
+
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ ):
+ pass
+
+
+class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ ):
+ return RegistrationBehaviour.ALLOW
+
+
+class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ ):
+ return RegistrationBehaviour.DENY
+
+
class RegistrationTestCase(unittest.HomeserverTestCase):
- """ Tests the RegistrationHandler. """
+ """Tests the RegistrationHandler."""
servlets = [
register_servlets,
@@ -51,6 +137,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config)
+
+ load_legacy_spam_checkers(hs)
+
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
return hs
def prepare(self, reactor, clock, hs):
@@ -474,34 +567,70 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".DenyAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_deny(self):
"""A spam checker can deny registration, which results in an error."""
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
- class DenyAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.DENY
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [DenyAll()]
+ @override_config(
+ {
+ "spam_checker": [
+ {
+ "module": TestSpamChecker.__module__ + ".LegacyAllowAll",
+ }
+ ]
+ }
+ )
+ def test_spam_checker_legacy_allow(self):
+ """Tests that a legacy spam checker implementing the legacy 3-arg version of the
+ check_registration_for_spam callback is correctly called.
+
+ In this test and the following one we test both success and failure to make sure
+ any failure comes from the spam checker (and not something else failing in the
+ call stack) and any success comes from the spam checker (and not because a
+ misconfiguration prevented it from being loaded).
+ """
+ self.get_success(self.handler.register_user(localpart="user"))
+ @override_config(
+ {
+ "spam_checker": [
+ {
+ "module": TestSpamChecker.__module__ + ".LegacyDenyAll",
+ }
+ ]
+ }
+ )
+ def test_spam_checker_legacy_deny(self):
+ """Tests that a legacy spam checker implementing the legacy 3-arg version of the
+ check_registration_for_spam callback is correctly called.
+
+ In this test and the previous one we test both success and failure to make sure
+ any failure comes from the spam checker (and not something else failing in the
+ call stack) and any success comes from the spam checker (and not because a
+ misconfiguration prevented it from being loaded).
+ """
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_shadow_ban(self):
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
-
- class BanAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.SHADOW_BAN
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanAll()]
-
user_id = self.get_success(self.handler.register_user(localpart="user"))
# Get an access token.
@@ -521,22 +650,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanBadIdPUser",
+ }
+ ]
+ }
+ )
def test_spam_checker_receives_sso_type(self):
"""Test rejecting registration based on SSO type"""
-
- class BanBadIdPUser:
- def check_registration_for_spam(
- self, email_threepid, username, request_info, auth_provider_id=None
- ):
- # Reject any user coming from CAS and whose username contains profanity
- if auth_provider_id == "cas" and "flimflob" in username:
- return RegistrationBehaviour.DENY
- return RegistrationBehaviour.ALLOW
-
- # Configure a spam checker that denies a certain user on a specific IdP
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanBadIdPUser()]
-
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
SynapseError,
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 2c5e81531b..131d362ccc 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -11,10 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Optional
+from typing import Any, Iterable, Optional, Tuple
from unittest import mock
+from synapse.api.errors import AuthError
from synapse.handlers.space_summary import _child_events_comparison_key
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
from tests import unittest
@@ -79,3 +84,95 @@ class TestSpaceSummarySort(unittest.TestCase):
ev1 = _create_event("!abc:test", "a" * 51)
self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+
+class SpaceSummaryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
+ self.hs = hs
+ self.handler = self.hs.get_space_summary_handler()
+
+ self.user = self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ def _add_child(self, space_id: str, room_id: str, token: str) -> None:
+ """Add a child room to a space."""
+ self.helper.send_state(
+ space_id,
+ event_type="m.space.child",
+ body={"via": [self.hs.hostname]},
+ tok=token,
+ state_key=room_id,
+ )
+
+ def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None:
+ """Assert that the expected room IDs are in the response."""
+ self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms)
+
+ def _assert_events(
+ self, result: JsonDict, events: Iterable[Tuple[str, str]]
+ ) -> None:
+ """Assert that the expected parent / child room IDs are in the response."""
+ self.assertCountEqual(
+ [
+ (event.get("room_id"), event.get("state_key"))
+ for event in result["events"]
+ ],
+ events,
+ )
+
+ def test_simple_space(self):
+ """Test a simple space with a single room."""
+ space = self.helper.create_room_as(self.user, tok=self.token)
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(space, room, self.token)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, space))
+ # The result should have the space and the room in it, along with a link
+ # from space -> room.
+ self._assert_rooms(result, [space, room])
+ self._assert_events(result, [(space, room)])
+
+ def test_visibility(self):
+ """A user not in a space cannot inspect it."""
+ space = self.helper.create_room_as(self.user, tok=self.token)
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(space, room, self.token)
+
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
+
+ # The user cannot see the space.
+ self.get_failure(self.handler.get_space_summary(user2, space), AuthError)
+
+ # Joining the room causes it to be visible.
+ self.helper.join(space, user2, tok=token2)
+ result = self.get_success(self.handler.get_space_summary(user2, space))
+
+ # The result should only have the space, but includes the link to the room.
+ self._assert_rooms(result, [space])
+ self._assert_events(result, [(space, room)])
+
+ def test_world_readable(self):
+ """A world-readable room is visible to everyone."""
+ space = self.helper.create_room_as(self.user, tok=self.token)
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(space, room, self.token)
+ self.helper.send_state(
+ space,
+ event_type="m.room.history_visibility",
+ body={"history_visibility": "world_readable"},
+ tok=self.token,
+ )
+
+ user2 = self.register_user("user2", "pass")
+
+ # The space should be visible, as well as the link to the room.
+ result = self.get_success(self.handler.get_space_summary(user2, space))
+ self._assert_rooms(result, [space])
+ self._assert_events(result, [(space, room)])
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index c8b43305f4..84f05f6c58 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -22,7 +22,7 @@ import tests.utils
class SyncTestCase(tests.unittest.HomeserverTestCase):
- """ Tests Sync Handler. """
+ """Tests Sync Handler."""
def prepare(self, reactor, clock, hs):
self.hs = hs
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 453686cc81..6bb13c1c04 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -312,15 +312,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+ async def allow_all(user_profile):
+ # Allow all users.
+ return False
+
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
-
- class AllowAll:
- async def check_username_for_spam(self, user_profile):
- # Allow all users.
- return False
-
- spam_checker.spam_checkers = [AllowAll()]
+ spam_checker._check_username_for_spam_callbacks = [allow_all]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -328,12 +326,11 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll:
- async def check_username_for_spam(self, user_profile):
- # All users are spammy.
- return True
+ async def block_all(user_profile):
+ # All users are spammy.
+ return True
- spam_checker.spam_checkers = [BlockAll()]
+ spam_checker._check_username_for_spam_callbacks = [block_all]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|