diff --git a/tests/config/test_base.py b/tests/config/test_base.py
new file mode 100644
index 0000000000..42ee5f56d9
--- /dev/null
+++ b/tests/config/test_base.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+import tempfile
+
+from synapse.config import ConfigError
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+
+class BaseConfigTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+
+ def test_loading_missing_templates(self):
+ # Use a temporary directory that exists on the system, but that isn't likely to
+ # contain template files
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Attempt to load an HTML template from our custom template directory
+ template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
+
+ # If no errors, we should've gotten the default template instead
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"error_description": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_custom_templates(self):
+ # Use a temporary directory that exists on the system
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ # Create a temporary bogus template file
+ with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
+ # Get temporary file's filename
+ template_filename = os.path.basename(tmp_template.name)
+
+ # Write a custom HTML template
+ contents = b"{{ test_variable }}"
+ tmp_template.write(contents)
+ tmp_template.flush()
+
+ # Attempt to load the template from our custom template directory
+ template = (
+ self.hs.config.read_templates([template_filename], tmp_dir)
+ )[0]
+
+ # Render the template
+ a_random_string = random_string(5)
+ html_content = template.render({"test_variable": a_random_string})
+
+ # Check that our string exists in the template
+ self.assertIn(
+ a_random_string,
+ html_content,
+ "Template file did not contain our test string",
+ )
+
+ def test_loading_template_from_nonexistent_custom_directory(self):
+ with self.assertRaises(ConfigError):
+ self.hs.config.read_templates(
+ ["some_filename.html"], "a_nonexistent_directory"
+ )
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index b8ca118716..9bd515080c 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -79,9 +79,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -110,9 +112,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -148,9 +152,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable(None)
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -204,9 +210,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -234,9 +242,11 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ fed_transport.client.get_json = Mock(
+ side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+ )
handler.federation_handler.do_invite_join = Mock(
- return_value=make_awaitable(("", 1))
+ side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
)
d = handler._remote_join(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d70e1fc608..b609b30d4a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase):
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- yield self.store.create_profile(self.frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler()
self.hs = hs
@@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
+ yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")
response = yield defer.ensureDeferred(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..5c92d0e8c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,18 +17,21 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable
from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 64afd581bc..e01de158e5 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
- self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
None
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 0b191d13c6..d4e7fa1293 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -140,7 +153,27 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+ The main reason for sending a second event is because currently Synapse won't
+ purge the latest message in a room because it would otherwise result in a lack of
+ forward extremities for this room. It's also a good thing to ensure the purge jobs
+ aren't too greedy and purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
@@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e74bddc1e5..c6c6edeac2 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -21,13 +21,13 @@
import json
from urllib import parse as urlparse
-from mock import Mock
+from mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import account
+from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
@@ -2059,3 +2059,158 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ShadowBannedTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 53a43038f0..2fc3a60fc5 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 2858d13558..23db821fb7 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 98b74890d5..17fbde284a 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -207,7 +208,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
- yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
- yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ )
+ yield defer.ensureDeferred(
+ self.store.set_appservice_state(service, ApplicationServiceState.UP)
+ )
rows = yield self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
@@ -349,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index efcaeef1e7..13bcac743a 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename", values={"columname": "Value"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename", values={"columname": "Value"}
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_insert(
- table="tablename",
- # Use OrderedDict() so we can assert on the SQL generated
- values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_insert(
+ table="tablename",
+ # Use OrderedDict() so we can assert on the SQL generated
+ values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+ )
)
self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 3fab5a5248..43639ca286 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID("alice", "test")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password")
- self.requester = Requester(self.user, None, False, None, None)
+ self.requester = Requester(self.user, None, False, False, None, None)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
+
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
- self.pump(1)
+ self.pump(1.01)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..949846fe33 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
# Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 857db071d4..238bad5b45 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db_pool.simple_insert(
- "events",
- {
- "stream_ordering": so,
- "received_ts": ts,
- "event_id": "event%i" % so,
- "type": "",
- "room_id": "",
- "content": "",
- "processed": True,
- "outlier": False,
- "topological_ordering": 0,
- "depth": 0,
- },
+ return defer.ensureDeferred(
+ self.store.db_pool.simple_insert(
+ "events",
+ {
+ "stream_ordering": so,
+ "received_ts": ts,
+ "event_id": "event%i" % so,
+ "type": "",
+ "room_id": "",
+ "content": "",
+ "processed": True,
+ "outlier": False,
+ "topological_ordering": 0,
+ "depth": 0,
+ },
+ )
)
# start with the base case where there are no events in the table
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..7a05194653 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token("first"), 3)
- self.assertEqual(first_id_gen.get_current_token("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..fbf8af940a 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
- yield self.store.create_profile(self.user.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
users, total = yield self.store.get_users_paginate(
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..9d5b8aa47d 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
@@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(self.u_frank.localpart)
+ yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a6012c973d..918387733b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = store.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
-
- # Purge everything before this topological token
- purge = defer.ensureDeferred(
- storage.purge_events.purge_history(self.room_id, event, True)
+ event = self.get_success(
+ store.get_topological_token_for_event(last["event_id"])
)
- self.pump()
- self.assertEqual(self.successResultOf(purge), None)
- # Try and get the events
- get_first = store.get_event(first["event_id"])
- get_second = store.get_event(second["event_id"])
- get_third = store.get_event(third["event_id"])
- get_last = store.get_event(last["event_id"])
- self.pump()
+ # Purge everything before this topological token
+ self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
- self.failureResultOf(get_first)
- self.failureResultOf(get_second)
- self.failureResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+ self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+ self.get_success(store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
"""
@@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = storage.get_topological_token_for_event(last["event_id"])
- self.pump()
- event = self.successResultOf(event)
+ event = self.get_success(
+ storage.get_topological_token_for_event(last["event_id"])
+ )
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
@@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase):
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
- get_first = storage.get_event(first["event_id"])
- get_second = storage.get_event(second["event_id"])
- get_third = storage.get_event(third["event_id"])
- get_last = storage.get_event(last["event_id"])
- self.pump()
-
- # Nothing is deleted.
- self.successResultOf(get_first)
- self.successResultOf(get_second)
- self.successResultOf(get_third)
- self.successResultOf(get_last)
+ self.get_success(storage.get_event(first["event_id"]))
+ self.get_success(storage.get_event(second["event_id"]))
+ self.get_success(storage.get_event(third["event_id"]))
+ self.get_success(storage.get_event(last["event_id"]))
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 840db66072..58f827d8d3 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
res = yield self.store.is_support_user(SUPPORT_USER)
self.assertTrue(res)
+
+ @defer.inlineCallbacks
+ def test_3pid_inhibit_invalid_validation_session_error(self):
+ """Tests that enabling the configuration option to inhibit 3PID errors on
+ /requestToken also inhibits validation errors caused by an unknown session ID.
+ """
+
+ # Check that, with the config setting set to false (the default value), a
+ # validation error is caused by the unknown session ID.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Unknown session_id", e)
+
+ # Set the config setting to true.
+ self.store._ignore_unknown_session_error = True
+
+ # Check that now the validation error is caused by the token not matching.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 17c9da4838..d98fe8754d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership
user = UserID("alice", "test")
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index f2fa42bfb9..4a4548433f 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -42,7 +42,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
user_id = UserID("us", "test")
- our_user = Requester(user_id, None, False, None, None)
+ our_user = Requester(user_id, None, False, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room_deferred = ensureDeferred(
room_creator.create_room(
diff --git a/tests/test_server.py b/tests/test_server.py
index d628070e48..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -178,7 +178,6 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
- self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
class OptionsResourceTests(unittest.TestCase):
diff --git a/tests/unittest.py b/tests/unittest.py
index d0bba3ddef..7b80999a74 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -250,7 +250,11 @@ class HomeserverTestCase(TestCase):
async def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
+ UserID.from_string(self.helper.auth_user_id),
+ 1,
+ False,
+ False,
+ None,
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -540,7 +544,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
event, context = self.get_success(
event_creator.create_event(
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..0363735d4f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
assert current_context().request == "c1"
return self.mock(args1, arg2)
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
return self.mock(args1, arg2)
obj = Cls()
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 4f4da29a98..8491f7cc83 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase):
"_--something==_",
"...--==-18913",
"8Dj2odd-e9asd.cd==_--ddas-secret-",
- # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
- # To be removed in a future release
- "SECRET:1234567890",
]
bad = [
|