diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py
index dcf336416c..b6bc1876b5 100644
--- a/tests/config/test_tls.py
+++ b/tests/config/test_tls.py
@@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-
import idna
-import yaml
from OpenSSL import SSL
@@ -39,58 +36,6 @@ class TestConfig(RootConfig):
class TLSConfigTests(TestCase):
- def test_warn_self_signed(self):
- """
- Synapse will give a warning when it loads a self-signed certificate.
- """
- config_dir = self.mktemp()
- os.mkdir(config_dir)
- with open(os.path.join(config_dir, "cert.pem"), "w") as f:
- f.write(
- """-----BEGIN CERTIFICATE-----
-MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
-BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
-Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
-QXV0b21hdGVkIFRlc3RpbmcgQXV0aG9yaXR5MSkwJwYJKoZIhvcNAQkBFhpzZWN1
-cml0eUB0d2lzdGVkbWF0cml4LmNvbTAgFw0xNzA3MTIxNDAxNTNaGA8yMTE3MDYx
-ODE0MDE1M1owgbcxCzAJBgNVBAYTAlRSMQ8wDQYDVQQIDAbDh29ydW0xFDASBgNV
-BAcMC0JhxZ9tYWvDp8SxMRIwEAYDVQQDDAlsb2NhbGhvc3QxHDAaBgNVBAoME1R3
-aXN0ZWQgTWF0cml4IExhYnMxJDAiBgNVBAsMG0F1dG9tYXRlZCBUZXN0aW5nIEF1
-dGhvcml0eTEpMCcGCSqGSIb3DQEJARYac2VjdXJpdHlAdHdpc3RlZG1hdHJpeC5j
-b20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDwT6kbqtMUI0sMkx4h
-I+L780dA59KfksZCqJGmOsMD6hte9EguasfkZzvCF3dk3NhwCjFSOvKx6rCwiteo
-WtYkVfo+rSuVNmt7bEsOUDtuTcaxTzIFB+yHOYwAaoz3zQkyVW0c4pzioiLCGCmf
-FLdiDBQGGp74tb+7a0V6kC3vMLFoM3L6QWq5uYRB5+xLzlPJ734ltyvfZHL3Us6p
-cUbK+3WTWvb4ER0W2RqArAj6Bc/ERQKIAPFEiZi9bIYTwvBH27OKHRz+KoY/G8zY
-+l+WZoJqDhupRAQAuh7O7V/y6bSP+KNxJRie9QkZvw1PSaGSXtGJI3WWdO12/Ulg
-epJpAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAJXEq5P9xwvP9aDkXIqzcD0L8sf8
-ewlhlxTQdeqt2Nace0Yk18lIo2oj1t86Y8jNbpAnZJeI813Rr5M7FbHCXoRc/SZG
-I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj
-iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
-SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
-s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
------END CERTIFICATE-----"""
- )
-
- config = {
- "tls_certificate_path": os.path.join(config_dir, "cert.pem"),
- }
-
- t = TestConfig()
- t.read_config(config, config_dir_path="", data_dir_path="")
- t.read_tls_certificate()
-
- warnings = self.flushWarnings()
- self.assertEqual(len(warnings), 1)
- self.assertEqual(
- warnings[0]["message"],
- (
- "Self-signed TLS certificates will not be accepted by "
- "Synapse 1.0. Please either provide a valid certificate, "
- "or use Synapse's ACME support to provision one."
- ),
- )
-
def test_tls_client_minimum_default(self):
"""
The default client TLS version is 1.0.
@@ -202,48 +147,6 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
- def test_acme_disabled_in_generated_config_no_acme_domain_provied(self):
- """
- Checks acme is disabled by default.
- """
- conf = TestConfig()
- conf.read_config(
- yaml.safe_load(
- TestConfig().generate_config(
- "/config_dir_path",
- "my_super_secure_server",
- "/data_dir_path",
- tls_certificate_path="/tls_cert_path",
- tls_private_key_path="tls_private_key",
- acme_domain=None, # This is the acme_domain
- )
- ),
- "/config_dir_path",
- )
-
- self.assertFalse(conf.acme_enabled)
-
- def test_acme_enabled_in_generated_config_domain_provided(self):
- """
- Checks acme is enabled if the acme_domain arg is set to some string.
- """
- conf = TestConfig()
- conf.read_config(
- yaml.safe_load(
- TestConfig().generate_config(
- "/config_dir_path",
- "my_super_secure_server",
- "/data_dir_path",
- tls_certificate_path="/tls_cert_path",
- tls_private_key_path="tls_private_key",
- acme_domain="my_supe_secure_server", # This is the acme_domain
- )
- ),
- "/config_dir_path",
- )
-
- self.assertTrue(conf.acme_enabled)
-
def test_whitelist_idna_failure(self):
"""
The federation certificate whitelist will not allow IDNA domain names.
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 5d6cc2885f..024c5e963c 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -26,7 +26,7 @@ from .. import unittest
class AppServiceHandlerTestCase(unittest.TestCase):
- """ Tests the ApplicationServicesHandler. """
+ """Tests the ApplicationServicesHandler."""
def setUp(self):
self.mock_store = Mock()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 1908d3c2c6..7a8041ab44 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -27,7 +27,7 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
- """ Tests the directory service. """
+ """Tests the directory service."""
def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index d90a9fec91..dfb9b3a0fa 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -863,7 +863,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(builder.build(prev_event_ids, None))
+ event = self.get_success(
+ builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+ )
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 5330a9b34e..cdb41101b3 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -23,7 +23,7 @@ from tests.test_utils import make_awaitable
class ProfileTestCase(unittest.HomeserverTestCase):
- """ Tests profile management. """
+ """Tests profile management."""
def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index bd43190523..a9fd3036dc 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -27,8 +27,60 @@ from tests.utils import mock_getRawHeaders
from .. import unittest
+class TestSpamChecker:
+ def __init__(self, config, api):
+ api.register_spam_checker_callbacks(
+ check_registration_for_spam=self.check_registration_for_spam,
+ )
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ pass
+
+
+class DenyAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.DENY
+
+
+class BanAll(TestSpamChecker):
+ async def check_registration_for_spam(
+ self,
+ email_threepid,
+ username,
+ request_info,
+ auth_provider_id,
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+
+class BanBadIdPUser(TestSpamChecker):
+ async def check_registration_for_spam(
+ self, email_threepid, username, request_info, auth_provider_id=None
+ ):
+ # Reject any user coming from CAS and whose username contains profanity
+ if auth_provider_id == "cas" and "flimflob" in username:
+ return RegistrationBehaviour.DENY
+ return RegistrationBehaviour.ALLOW
+
+
class RegistrationTestCase(unittest.HomeserverTestCase):
- """ Tests the RegistrationHandler. """
+ """Tests the RegistrationHandler."""
def make_homeserver(self, reactor, clock):
hs_config = self.default_config()
@@ -42,6 +94,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config["limit_usage_by_mau"] = True
hs = self.setup_test_homeserver(config=hs_config)
+
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
return hs
def prepare(self, reactor, clock, hs):
@@ -465,34 +522,30 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".DenyAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_deny(self):
"""A spam checker can deny registration, which results in an error."""
-
- class DenyAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.DENY
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [DenyAll()]
-
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanAll",
+ }
+ ]
+ }
+ )
def test_spam_checker_shadow_ban(self):
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
-
- class BanAll:
- def check_registration_for_spam(
- self, email_threepid, username, request_info
- ):
- return RegistrationBehaviour.SHADOW_BAN
-
- # Configure a spam checker that denies all users.
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanAll()]
-
user_id = self.get_success(self.handler.register_user(localpart="user"))
# Get an access token.
@@ -512,22 +565,17 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(requester.shadow_banned)
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": TestSpamChecker.__module__ + ".BanBadIdPUser",
+ }
+ ]
+ }
+ )
def test_spam_checker_receives_sso_type(self):
"""Test rejecting registration based on SSO type"""
-
- class BanBadIdPUser:
- def check_registration_for_spam(
- self, email_threepid, username, request_info, auth_provider_id=None
- ):
- # Reject any user coming from CAS and whose username contains profanity
- if auth_provider_id == "cas" and "flimflob" in username:
- return RegistrationBehaviour.DENY
- return RegistrationBehaviour.ALLOW
-
- # Configure a spam checker that denies a certain user on a specific IdP
- spam_checker = self.hs.get_spam_checker()
- spam_checker.spam_checkers = [BanBadIdPUser()]
-
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
SynapseError,
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 2c5e81531b..131d362ccc 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -11,10 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Optional
+from typing import Any, Iterable, Optional, Tuple
from unittest import mock
+from synapse.api.errors import AuthError
from synapse.handlers.space_summary import _child_events_comparison_key
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
from tests import unittest
@@ -79,3 +84,95 @@ class TestSpaceSummarySort(unittest.TestCase):
ev1 = _create_event("!abc:test", "a" * 51)
self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+
+class SpaceSummaryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
+ self.hs = hs
+ self.handler = self.hs.get_space_summary_handler()
+
+ self.user = self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ def _add_child(self, space_id: str, room_id: str, token: str) -> None:
+ """Add a child room to a space."""
+ self.helper.send_state(
+ space_id,
+ event_type="m.space.child",
+ body={"via": [self.hs.hostname]},
+ tok=token,
+ state_key=room_id,
+ )
+
+ def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None:
+ """Assert that the expected room IDs are in the response."""
+ self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms)
+
+ def _assert_events(
+ self, result: JsonDict, events: Iterable[Tuple[str, str]]
+ ) -> None:
+ """Assert that the expected parent / child room IDs are in the response."""
+ self.assertCountEqual(
+ [
+ (event.get("room_id"), event.get("state_key"))
+ for event in result["events"]
+ ],
+ events,
+ )
+
+ def test_simple_space(self):
+ """Test a simple space with a single room."""
+ space = self.helper.create_room_as(self.user, tok=self.token)
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(space, room, self.token)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, space))
+ # The result should have the space and the room in it, along with a link
+ # from space -> room.
+ self._assert_rooms(result, [space, room])
+ self._assert_events(result, [(space, room)])
+
+ def test_visibility(self):
+ """A user not in a space cannot inspect it."""
+ space = self.helper.create_room_as(self.user, tok=self.token)
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(space, room, self.token)
+
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
+
+ # The user cannot see the space.
+ self.get_failure(self.handler.get_space_summary(user2, space), AuthError)
+
+ # Joining the room causes it to be visible.
+ self.helper.join(space, user2, tok=token2)
+ result = self.get_success(self.handler.get_space_summary(user2, space))
+
+ # The result should only have the space, but includes the link to the room.
+ self._assert_rooms(result, [space])
+ self._assert_events(result, [(space, room)])
+
+ def test_world_readable(self):
+ """A world-readable room is visible to everyone."""
+ space = self.helper.create_room_as(self.user, tok=self.token)
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(space, room, self.token)
+ self.helper.send_state(
+ space,
+ event_type="m.room.history_visibility",
+ body={"history_visibility": "world_readable"},
+ tok=self.token,
+ )
+
+ user2 = self.register_user("user2", "pass")
+
+ # The space should be visible, as well as the link to the room.
+ result = self.get_success(self.handler.get_space_summary(user2, space))
+ self._assert_rooms(result, [space])
+ self._assert_events(result, [(space, room)])
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index c8b43305f4..84f05f6c58 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -22,7 +22,7 @@ import tests.utils
class SyncTestCase(tests.unittest.HomeserverTestCase):
- """ Tests Sync Handler. """
+ """Tests Sync Handler."""
def prepare(self, reactor, clock, hs):
self.hs = hs
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index daac37abd8..549876dc85 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -312,15 +312,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+ async def allow_all(user_profile):
+ # Allow all users.
+ return False
+
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
-
- class AllowAll:
- async def check_username_for_spam(self, user_profile):
- # Allow all users.
- return False
-
- spam_checker.spam_checkers = [AllowAll()]
+ spam_checker._check_username_for_spam_callbacks = [allow_all]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -328,12 +326,11 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll:
- async def check_username_for_spam(self, user_profile):
- # All users are spammy.
- return True
+ async def block_all(user_profile):
+ # All users are spammy.
+ return True
- spam_checker.spam_checkers = [BlockAll()]
+ spam_checker._check_username_for_spam_callbacks = [block_all]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 48ab3aa4e3..584da58371 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -224,7 +224,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
}
builder = factory.for_room_version(room_version, event_dict)
- join_event = self.get_success(builder.build(prev_event_ids, None))
+ join_event = self.get_success(
+ builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
+ )
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 852bda408c..2789d51546 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -23,7 +23,7 @@ from tests import unittest
class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
- """ Tests event streaming (GET /events). """
+ """Tests event streaming (GET /events)."""
servlets = [
events.register_servlets,
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 409f3949dc..597e4c67de 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -24,7 +24,7 @@ from tests import unittest
class PresenceTestCase(unittest.HomeserverTestCase):
- """ Tests presence REST API. """
+ """Tests presence REST API."""
user_id = "@sid:red"
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5b1096d091..e94566ffd7 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -64,7 +64,7 @@ class RoomBase(unittest.HomeserverTestCase):
class RoomPermissionsTestCase(RoomBase):
- """ Tests room permissions. """
+ """Tests room permissions."""
user_id = "@sid1:red"
rmcreator_id = "@notme:red"
@@ -377,7 +377,7 @@ class RoomPermissionsTestCase(RoomBase):
class RoomsMemberListTestCase(RoomBase):
- """ Tests /rooms/$room_id/members/list REST events."""
+ """Tests /rooms/$room_id/members/list REST events."""
user_id = "@sid1:red"
@@ -416,7 +416,7 @@ class RoomsMemberListTestCase(RoomBase):
class RoomsCreateTestCase(RoomBase):
- """ Tests /rooms and /rooms/$room_id REST events. """
+ """Tests /rooms and /rooms/$room_id REST events."""
user_id = "@sid1:red"
@@ -502,7 +502,7 @@ class RoomsCreateTestCase(RoomBase):
class RoomTopicTestCase(RoomBase):
- """ Tests /rooms/$room_id/topic REST events. """
+ """Tests /rooms/$room_id/topic REST events."""
user_id = "@sid1:red"
@@ -566,7 +566,7 @@ class RoomTopicTestCase(RoomBase):
class RoomMemberStateTestCase(RoomBase):
- """ Tests /rooms/$room_id/members/$user_id/state REST events. """
+ """Tests /rooms/$room_id/members/$user_id/state REST events."""
user_id = "@sid1:red"
@@ -790,7 +790,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
class RoomMessagesTestCase(RoomBase):
- """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
+ """Tests /rooms/$room_id/messages/$user_id/$msg_id REST events."""
user_id = "@sid1:red"
@@ -838,7 +838,7 @@ class RoomMessagesTestCase(RoomBase):
class RoomInitialSyncTestCase(RoomBase):
- """ Tests /rooms/$room_id/initialSync. """
+ """Tests /rooms/$room_id/initialSync."""
user_id = "@sid1:red"
@@ -879,7 +879,7 @@ class RoomInitialSyncTestCase(RoomBase):
class RoomMessageListTestCase(RoomBase):
- """ Tests /rooms/$room_id/messages REST events. """
+ """Tests /rooms/$room_id/messages REST events."""
user_id = "@sid1:red"
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 0aad48a162..44e22ca999 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -26,7 +26,7 @@ PATH_PREFIX = "/_matrix/client/api/v1"
class RoomTypingTestCase(unittest.HomeserverTestCase):
- """ Tests /rooms/$room_id/typing/$user_id REST API. """
+ """Tests /rooms/$room_id/typing/$user_id REST API."""
user_id = "@sid:red"
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index b52f78ba69..012910f136 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -558,3 +558,53 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Store the next batch for the next request.
self.next_batch = channel.json_body["next_batch"]
+
+
+class SyncCacheTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def test_noop_sync_does_not_tightloop(self):
+ """If the sync times out, we shouldn't cache the result
+
+ Essentially a regression test for #8518.
+ """
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # we should immediately get an initial sync response
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # now, make an incremental sync request, with a timeout
+ next_batch = channel.json_body["next_batch"]
+ channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch}&timeout=10000",
+ access_token=self.tok,
+ await_result=False,
+ )
+ # that should block for 10 seconds
+ with self.assertRaises(TimedOutException):
+ channel.await_result(timeout_ms=9900)
+ channel.await_result(timeout_ms=200)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # we expect the next_batch in the result to be the same as before
+ self.assertEqual(channel.json_body["next_batch"], next_batch)
+
+ # another incremental sync should also block.
+ channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch}&timeout=10000",
+ access_token=self.tok,
+ await_result=False,
+ )
+ # that should block for 10 seconds
+ with self.assertRaises(TimedOutException):
+ channel.await_result(timeout_ms=9900)
+ channel.await_result(timeout_ms=200)
+ self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 4a213d13dd..95e7075841 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -27,6 +27,7 @@ from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred
+from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
from synapse.rest import admin
from synapse.rest.client.v1 import login
@@ -535,6 +536,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
+ load_legacy_spam_checkers(hs)
+
def default_config(self):
config = default_config("test")
diff --git a/tests/server.py b/tests/server.py
index 9df8cda24f..f32d8dc375 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -138,21 +138,19 @@ class FakeChannel:
def transport(self):
return self
- def await_result(self, timeout: int = 100) -> None:
+ def await_result(self, timeout_ms: int = 1000) -> None:
"""
Wait until the request is finished.
"""
+ end_time = self._reactor.seconds() + timeout_ms / 1000.0
self._reactor.run()
- x = 0
while not self.is_finished():
# If there's a producer, tell it to resume producing so we get content
if self._producer:
self._producer.resumeProducing()
- x += 1
-
- if x > timeout:
+ if self._reactor.seconds() > end_time:
raise TimedOutException("Timed out waiting for request to finish.")
self._reactor.advance(0.1)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 3b45a7efd8..ddad44bd6c 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -27,7 +27,7 @@ from tests.utils import TestHomeServer, default_config
class SQLBaseStoreTestCase(unittest.TestCase):
- """ Test the "simple" SQL generating methods in SQLBaseStore. """
+ """Test the "simple" SQL generating methods in SQLBaseStore."""
def setUp(self):
self.db_pool = Mock(spec=["runInteraction"])
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index bb31ab756d..dbacce4380 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -232,9 +232,14 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder
self._event_id = event_id
- async def build(self, prev_event_ids, auth_event_ids):
+ async def build(
+ self,
+ prev_event_ids,
+ auth_event_ids,
+ depth: Optional[int] = None,
+ ):
built_event = await self._base_builder.build(
- prev_event_ids, auth_event_ids
+ prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
built_event._event_id = self._event_id
@@ -251,6 +256,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def type(self):
return self._base_builder.type
+ @property
+ def internal_metadata(self):
+ return self._base_builder.internal_metadata
+
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
|