diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index f76fea4f66..cccff7af26 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -25,7 +25,9 @@ from synapse.api.errors import (
MissingClientTokenError,
ResourceLimitError,
)
+from synapse.appservice import ApplicationService
from synapse.storage.databases.main.registration import TokenLookupResult
+from synapse.types import Requester
from tests import unittest
from tests.test_utils import simple_async_mock
@@ -217,7 +219,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key,
+ key=self.hs.config.key.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
@@ -239,7 +241,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key,
+ key=self.hs.config.key.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
@@ -290,6 +292,66 @@ class AuthTestCase(unittest.HomeserverTestCase):
# Real users not allowed
self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
+ self.auth_blocking._max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._track_appservice_user_ips = False
+
+ self.store.get_monthly_active_count = simple_async_mock(100)
+ self.store.user_last_seen_monthly_active = simple_async_mock()
+ self.store.is_trial_user = simple_async_mock()
+
+ appservice = ApplicationService(
+ "abcd",
+ self.hs.config.server_name,
+ id="1234",
+ namespaces={
+ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
+ },
+ sender="@appservice:sender",
+ )
+ requester = Requester(
+ user="@appservice:server",
+ access_token_id=None,
+ device_id="FOOBAR",
+ is_guest=False,
+ shadow_banned=False,
+ app_service=appservice,
+ authenticated_entity="@appservice:server",
+ )
+ self.get_success(self.auth.check_auth_blocking(requester=requester))
+
+ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
+ self.auth_blocking._max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._track_appservice_user_ips = True
+
+ self.store.get_monthly_active_count = simple_async_mock(100)
+ self.store.user_last_seen_monthly_active = simple_async_mock()
+ self.store.is_trial_user = simple_async_mock()
+
+ appservice = ApplicationService(
+ "abcd",
+ self.hs.config.server_name,
+ id="1234",
+ namespaces={
+ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
+ },
+ sender="@appservice:sender",
+ )
+ requester = Requester(
+ user="@appservice:server",
+ access_token_id=None,
+ device_id="FOOBAR",
+ is_guest=False,
+ shadow_banned=False,
+ app_service=appservice,
+ authenticated_entity="@appservice:server",
+ )
+ self.get_failure(
+ self.auth.check_auth_blocking(requester=requester), ResourceLimitError
+ )
+
def test_reserved_threepid(self):
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index d66aeb00eb..19eb4c79d0 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -172,7 +172,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
# We don't want our tests to actually report statistics, so check
# that it's not enabled
- assert not hs.config.report_stats
+ assert not hs.config.metrics.report_stats
# This starts the needed data collection that we rely on to calculate
# R30v2 metrics.
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 903c69127d..ef6c2beec7 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -52,10 +52,10 @@ class ConfigLoadingTestCase(unittest.TestCase):
hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
- if len(config.macaroon_secret_key) < 5:
+ if len(config.key.macaroon_secret_key) < 5:
self.fail(
"Want macaroon secret key to be string of at least length 5,"
- "was: %r" % (config.macaroon_secret_key,)
+ "was: %r" % (config.key.macaroon_secret_key,)
)
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
@@ -63,10 +63,10 @@ class ConfigLoadingTestCase(unittest.TestCase):
hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
- if len(config.macaroon_secret_key) < 5:
+ if len(config.key.macaroon_secret_key) < 5:
self.fail(
"Want macaroon secret key to be string of at least length 5,"
- "was: %r" % (config.macaroon_secret_key,)
+ "was: %r" % (config.key.macaroon_secret_key,)
)
def test_load_succeeds_if_macaroon_secret_key_missing(self):
@@ -101,7 +101,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
# The default Metrics Flags are off by default.
config = HomeServerConfig.load_config("", ["-c", self.file])
- self.assertFalse(config.metrics_flags.known_servers)
+ self.assertFalse(config.metrics.metrics_flags.known_servers)
def generate_config(self):
with redirect_stdout(StringIO()):
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index 3c7bb32e07..1b63e1adfd 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -30,7 +30,7 @@ class RatelimitConfigTestCase(TestCase):
config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
- config_obj = config.rc_federation
+ config_obj = config.ratelimiting.rc_federation
self.assertEqual(config_obj.window_size, 20000)
self.assertEqual(config_obj.sleep_limit, 693)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 5f3350e490..12857053e7 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -67,7 +67,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_type)
v.satisfy_general(verify_nonce)
v.satisfy_general(verify_guest)
- v.verify(macaroon, self.hs.config.macaroon_secret_key)
+ v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token(
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index a0a48b564e..6a2e76ca4a 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -405,7 +405,9 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
rd_config = RoomDirectoryConfig()
rd_config.read_config(config)
- self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed
+ self.hs.config.roomdirectory.is_alias_creation_allowed = (
+ rd_config.is_alias_creation_allowed
+ )
return hs
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 6c67a16de9..936ebf3dde 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -308,7 +308,12 @@ class FederationTestCase(unittest.HomeserverTestCase):
async def get_event_auth(
destination: str, room_id: str, event_id: str
) -> List[EventBase]:
- return auth_events
+ return [
+ event_from_pdu_json(
+ ae.get_pdu_json(), room_version=room_version, outlier=True
+ )
+ for ae in auth_events
+ ]
self.handler.federation_client.get_event_auth = get_event_auth
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 2928c4f48c..57cc3e2646 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
import synapse.types
from synapse.api.errors import AuthError, SynapseError
+from synapse.rest import admin
from synapse.types import UserID
from tests import unittest
@@ -25,6 +26,8 @@ from tests.test_utils import make_awaitable
class ProfileTestCase(unittest.HomeserverTestCase):
"""Tests profile management."""
+ servlets = [admin.register_servlets]
+
def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
self.mock_registry = Mock()
@@ -46,11 +49,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.frank = UserID.from_string("@1234ABCD:test")
+ self.frank = UserID.from_string("@1234abcd:test")
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- self.get_success(self.store.create_profile(self.frank.localpart))
+ self.get_success(self.register_user(self.frank.localpart, "frankpassword"))
self.handler = hs.get_profile_handler()
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 732a12c9bd..5de89c873b 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -23,7 +23,7 @@ from tests import unittest
class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
- self.event_source = hs.get_event_sources().sources["receipt"]
+ self.event_source = hs.get_event_sources().sources.receipt
# In the first param of _test_filters_hidden we use "hidden" instead of
# ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 1ba4c05b9b..24b7ef6efc 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -118,7 +118,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(len(r), 0)
# Disable stats
- self.hs.config.stats_enabled = False
+ self.hs.config.stats.stats_enabled = False
self.handler.stats_enabled = False
u1 = self.register_user("u1", "pass")
@@ -134,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(len(r), 0)
# Enable stats
- self.hs.config.stats_enabled = True
+ self.hs.config.stats.stats_enabled = True
self.handler.stats_enabled = True
# Do the initial population of the user directory via the background update
@@ -469,7 +469,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
behaviour eventually to still keep current rows.
"""
- self.hs.config.stats_enabled = False
+ self.hs.config.stats.stats_enabled = False
self.handler.stats_enabled = False
u1 = self.register_user("u1", "pass")
@@ -481,7 +481,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertIsNone(self._get_current_stats("room", r1))
self.assertIsNone(self._get_current_stats("user", u1))
- self.hs.config.stats_enabled = True
+ self.hs.config.stats.stats_enabled = True
self.handler.stats_enabled = True
self._perform_background_initial_update()
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index fa3cff598e..000f9b9fde 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -89,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_typing_handler()
- self.event_source = hs.get_event_sources().sources["typing"]
+ self.event_source = hs.get_event_sources().sources.typing
self.datastore = hs.get_datastore()
self.datastore.get_destination_retry_timings = Mock(
@@ -171,7 +171,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ )
)
self.assertEquals(
events[0],
@@ -239,7 +241,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ )
)
self.assertEquals(
events[0],
@@ -276,7 +280,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=0,
+ limit=None,
+ room_ids=[OTHER_ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(events[0], [])
self.assertEquals(events[1], 0)
@@ -324,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ )
)
self.assertEquals(
events[0],
@@ -350,7 +362,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=0,
+ limit=None,
+ room_ids=[ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
@@ -369,7 +387,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 2)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=1,
+ limit=None,
+ room_ids=[ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
@@ -392,7 +416,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 3)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=0,
+ limit=None,
+ room_ids=[ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index ae88ed89aa..266333c553 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
from unittest.mock import Mock
+from urllib.parse import quote
from twisted.internet import defer
@@ -20,6 +22,7 @@ from synapse.api.constants import UserTypes
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.rest.client import login, room, user_directory
from synapse.storage.roommember import ProfileInfo
+from synapse.types import create_requester
from tests import unittest
from tests.unittest import override_config
@@ -32,7 +35,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
+ synapse.rest.admin.register_servlets,
room.register_servlets,
]
@@ -130,6 +133,44 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
self.store.remove_from_user_dir.called_once_with(r_user_id)
+ def test_reactivation_makes_regular_user_searchable(self):
+ user = self.register_user("regular", "pass")
+ user_token = self.login(user, "pass")
+ admin_user = self.register_user("admin", "pass", admin=True)
+ admin_token = self.login(admin_user, "pass")
+
+ # Ensure the regular user is publicly visible and searchable.
+ self.helper.create_room_as(user, is_public=True, tok=user_token)
+ s = self.get_success(self.handler.search_users(admin_user, user, 10))
+ self.assertEqual(len(s["results"]), 1)
+ self.assertEqual(s["results"][0]["user_id"], user)
+
+ # Deactivate the user and check they're not searchable.
+ deactivate_handler = self.hs.get_deactivate_account_handler()
+ self.get_success(
+ deactivate_handler.deactivate_account(
+ user, erase_data=False, requester=create_requester(admin_user)
+ )
+ )
+ s = self.get_success(self.handler.search_users(admin_user, user, 10))
+ self.assertEqual(s["results"], [])
+
+ # Reactivate the user
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{quote(user)}",
+ access_token=admin_token,
+ content={"deactivated": False, "password": "pass"},
+ )
+ self.assertEqual(channel.code, 200)
+ user_token = self.login(user, "pass")
+ self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check they're searchable.
+ s = self.get_success(self.handler.search_users(admin_user, user, 10))
+ self.assertEqual(len(s["results"]), 1)
+ self.assertEqual(s["results"][0]["user_id"], user)
+
def test_private_room(self):
"""
A user can be searched for only by people that are either in a public
@@ -285,7 +326,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
return r
- def get_users_in_public_rooms(self):
+ def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
r = self.get_success(
self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
@@ -296,7 +337,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
retval.append((i["user_id"], i["room_id"]))
return retval
- def get_users_who_share_private_rooms(self):
+ def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
return self.get_success(
self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
@@ -410,7 +451,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
visible.
"""
self.handler.search_all_users = True
- self.hs.config.user_directory_search_all_users = True
+ self.hs.config.userdirectory.user_directory_search_all_users = True
u1 = self.register_user("user1", "pass")
self.register_user("user2", "pass")
@@ -566,7 +607,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
return hs
def test_disabling_room_list(self):
- self.config.user_directory_search_enabled = True
+ self.config.userdirectory.user_directory_search_enabled = True
# First we create a room with another user so that user dir is non-empty
# for our user
@@ -583,7 +624,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.assertTrue(len(channel.json_body["results"]) > 0)
# Disable user directory and check search returns nothing
- self.config.user_directory_search_enabled = False
+ self.config.userdirectory.user_directory_search_enabled = False
channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 768c2ba4ea..391196425c 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
- channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ channel = make_request(
+ self.reactor, FakeSite(resource, self.reactor), "GET", "/"
+ )
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
@@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
- channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ channel = make_request(
+ self.reactor, FakeSite(resource, self.reactor), "GET", "/"
+ )
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 1160716929..f73fcd684e 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
site.site_tag = "test-site"
site.server_version_string = "Server v1"
- request = SynapseRequest(FakeChannel(site, None))
+ site.reactor = Mock()
+ request = SynapseRequest(FakeChannel(site, None), site)
# Call requestReceived to finish instantiating the object.
request.content = BytesIO()
# Partially skip some of the internal processing of SynapseRequest.
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 7dd519cd44..9d38974fba 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -43,6 +43,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api = homeserver.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler()
self.sync_handler = homeserver.get_sync_handler()
+ self.auth_handler = homeserver.get_auth_handler()
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
@@ -89,6 +90,77 @@ class ModuleApiTestCase(HomeserverTestCase):
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user)
+ def test_get_user_ip_and_agents(self):
+ user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
+
+ # Initially, we should have no ip/agent for our user.
+ info = self.get_success(self.module_api.get_user_ip_and_agents(user_id))
+ self.assertEqual(info, [])
+
+ # Insert a first ip, agent. We should be able to retrieve it.
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip_1", "user_agent_1", "device_1", None
+ )
+ )
+ info = self.get_success(self.module_api.get_user_ip_and_agents(user_id))
+
+ self.assertEqual(len(info), 1)
+ last_seen_1 = info[0].last_seen
+
+ # Insert a second ip, agent at a later date. We should be able to retrieve it.
+ last_seen_2 = last_seen_1 + 10000
+ print("%s => %s" % (last_seen_1, last_seen_2))
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip_2", "user_agent_2", "device_2", last_seen_2
+ )
+ )
+ info = self.get_success(self.module_api.get_user_ip_and_agents(user_id))
+
+ self.assertEqual(len(info), 2)
+ ip_1_seen = False
+ ip_2_seen = False
+
+ for i in info:
+ if i.ip == "ip_1":
+ ip_1_seen = True
+ self.assertEqual(i.user_agent, "user_agent_1")
+ self.assertEqual(i.last_seen, last_seen_1)
+ elif i.ip == "ip_2":
+ ip_2_seen = True
+ self.assertEqual(i.user_agent, "user_agent_2")
+ self.assertEqual(i.last_seen, last_seen_2)
+ self.assertTrue(ip_1_seen)
+ self.assertTrue(ip_2_seen)
+
+ # If we fetch from a midpoint between last_seen_1 and last_seen_2,
+ # we should only find the second ip, agent.
+ info = self.get_success(
+ self.module_api.get_user_ip_and_agents(
+ user_id, (last_seen_1 + last_seen_2) / 2
+ )
+ )
+ self.assertEqual(len(info), 1)
+ self.assertEqual(info[0].ip, "ip_2")
+ self.assertEqual(info[0].user_agent, "user_agent_2")
+ self.assertEqual(info[0].last_seen, last_seen_2)
+
+ # If we fetch from a point later than last_seen_2, we shouldn't
+ # find anything.
+ info = self.get_success(
+ self.module_api.get_user_ip_and_agents(user_id, last_seen_2 + 10000)
+ )
+ self.assertEqual(info, [])
+
+ def test_get_user_ip_and_agents__no_user_found(self):
+ info = self.get_success(
+ self.module_api.get_user_ip_and_agents(
+ "@test_get_user_ip_and_agents_user_nonexistent:example.com"
+ )
+ )
+ self.assertEqual(info, [])
+
def test_sending_events_into_room(self):
"""Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index e9fd991718..c7555c26db 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -328,7 +328,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up TCP replication between master and the new worker if we don't
# have Redis support enabled.
- if not worker_hs.config.redis_enabled:
+ if not worker_hs.config.redis.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs,
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 01b1b0d4a0..13aa5eb51a 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
resource = hs.get_media_repository_resource().children[b"download"]
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"GET",
f"/{target}/{media_id}",
shorthand=False,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index febd40b656..192073c520 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Ensure a piece of media is quarantined when trying to access it."""
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
server_name_and_media_id,
shorthand=False,
@@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access each piece of media
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
server_and_media_id_2,
shorthand=False,
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 2f02934e72..ce30a19213 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -43,7 +43,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
def test_no_auth(self):
"""
@@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -200,7 +200,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
def test_no_auth(self):
@@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 4927321e5a..9bac423ae0 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -95,8 +95,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_create_specifying_fields(self):
"""Create a token specifying the value of all fields."""
+ # As many of the allowed characters as possible with length <= 64
+ token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-"
data = {
- "token": "abcd",
+ "token": token,
"uses_allowed": 1,
"expiry_time": self.clock.time_msec() + 1000000,
}
@@ -109,7 +111,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
@@ -193,7 +195,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
"""Check right error is raised when server can't generate unique token."""
# Create all possible single character tokens
tokens = []
- for c in string.ascii_letters + string.digits + "-_":
+ for c in string.ascii_letters + string.digits + "._~-":
tokens.append(
{
"token": c,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 40e032df7f..0fa55e03b4 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -47,7 +47,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
+ hs.config.consent.user_consent_version = "1"
consent_uri_builder = Mock()
consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
@@ -941,6 +941,33 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
+ def test_search_term_non_ascii(self):
+ """Test that searching for a room with non-ASCII characters works correctly"""
+
+ # Create test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_name = "ж"
+
+ # Set the name for the room
+ self.helper.send_state(
+ room_id,
+ "m.room.name",
+ {"name": room_name},
+ tok=self.admin_user_tok,
+ )
+
+ # make the request and test that the response is what we wanted
+ search_term = urllib.parse.quote("ж", "utf-8")
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
+ self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
+
def test_single_room(self):
"""Test that a single room can be requested correctly"""
# Create two test rooms
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cc3f16c62a..ee3ae9cce4 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2473,7 +2473,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource()
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts`
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index b946fca8b3..9e9e953cf4 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Load the password reset confirmation page
channel = make_request(
self.reactor,
- FakeSite(self.submit_token_resource),
+ FakeSite(self.submit_token_resource, self.reactor),
"GET",
path,
shorthand=False,
@@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Confirm the password reset
channel = make_request(
self.reactor,
- FakeSite(self.submit_token_resource),
+ FakeSite(self.submit_token_resource, self.reactor),
"POST",
path,
content=b"",
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 65c58ce70a..84d092ca82 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
channel = make_request(
- self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+ self.reactor,
+ FakeSite(resource, self.reactor),
+ "GET",
+ "/consent?v=1",
+ shorthand=False,
)
self.assertEqual(channel.code, 200)
@@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
)
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,
@@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# POST to the consent page, saying we've agreed
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
@@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# changed
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f5c195a075..371615a015 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -97,7 +97,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
- self.hs.config.enable_registration_captcha = False
+ self.hs.config.captcha.enable_registration_captcha = False
return self.hs
@@ -815,9 +815,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.jwt_enabled = True
- self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = self.jwt_algorithm
+ self.hs.config.jwt.jwt_enabled = True
+ self.hs.config.jwt.jwt_secret = self.jwt_secret
+ self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
@@ -1023,9 +1023,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.jwt_enabled = True
- self.hs.config.jwt_secret = self.jwt_pubkey
- self.hs.config.jwt_algorithm = "RS256"
+ self.hs.config.jwt.jwt_enabled = True
+ self.hs.config.jwt.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt.jwt_algorithm = "RS256"
return self.hs
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 9f3ab2c985..72a5a11b46 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -146,7 +146,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
- self.hs.config.macaroon_secret_key = "test"
+ self.hs.config.key.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 50100a5ae4..30bdaa9c27 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Iterable
+from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
@@ -26,11 +26,11 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
-from synapse.types import JsonDict, RoomAlias, UserID, create_requester
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -377,6 +377,91 @@ class RoomPermissionsTestCase(RoomBase):
expect_code=403,
)
+ # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
+ def test_member_event_from_ban(self):
+ room = self.created_rmid
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
+
+ other = "@burgundy:red"
+
+ # User cannot ban other since they do not have required power level
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.BAN,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.FORBIDDEN,
+ )
+
+ # Admin bans other
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.BAN,
+ expect_code=200,
+ )
+
+ # from ban to invite: Must never happen.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.INVITE,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.BAD_STATE,
+ )
+
+ # from ban to join: Must never happen.
+ self.helper.change_membership(
+ room=room,
+ src=other,
+ targ=other,
+ membership=Membership.JOIN,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.BAD_STATE,
+ )
+
+ # from ban to ban: No change.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.BAN,
+ expect_code=200,
+ )
+
+ # from ban to knock: Must never happen.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.KNOCK,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.BAD_STATE,
+ )
+
+ # User cannot unban other since they do not have required power level
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.LEAVE,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.FORBIDDEN,
+ )
+
+ # from ban to leave: User was unbanned.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.LEAVE,
+ expect_code=200,
+ )
+
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
@@ -584,6 +669,121 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
+ def test_spamchecker_invites(self):
+ """Tests the user_may_create_room_with_invites spam checker callback."""
+
+ # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an
+ # IS.
+ async def do_3pid_invite(
+ room_id: str,
+ inviter: UserID,
+ medium: str,
+ address: str,
+ id_server: str,
+ requester: Requester,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> int:
+ return 0
+
+ do_3pid_invite_mock = Mock(side_effect=do_3pid_invite)
+ self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock
+
+ # Add a mock callback for user_may_create_room_with_invites. Make it allow any
+ # room creation request for now.
+ return_value = True
+
+ async def user_may_create_room_with_invites(
+ user: str,
+ invites: List[str],
+ threepid_invites: List[Dict[str, str]],
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_create_room_with_invites)
+ self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append(
+ callback_mock,
+ )
+
+ # The MXIDs we'll try to invite.
+ invited_mxids = [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+
+ # The 3PIDs we'll try to invite.
+ invited_3pids = [
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice1@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice2@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice3@example.com",
+ },
+ ]
+
+ # Create a room and invite the Matrix users, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite": invited_mxids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, invited_mxids, []),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Create a room and invite the 3PIDs, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that do_3pid_invite was called the right amount of time
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, [], invited_3pids),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now deny any room creation.
+ return_value = False
+
+ # Create a room and invite the 3PIDs, and check that it failed.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(403, channel.code)
+
+ # Check that do_3pid_invite wasn't called this time.
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -784,6 +984,12 @@ class RoomJoinRatelimitTestCase(RoomBase):
room.register_servlets,
]
+ def prepare(self, reactor, clock, homeserver):
+ super().prepare(reactor, clock, homeserver)
+ # profile changes expect that the user is actually registered
+ user = UserID.from_string(self.user_id)
+ self.get_success(self.register_user(user.localpart, "supersecretpassword"))
+
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
@@ -813,12 +1019,6 @@ class RoomJoinRatelimitTestCase(RoomBase):
# join in a second.
room_ids.append(self.helper.create_room_as(self.user_id))
- # Create a profile for the user, since it hasn't been done on registration.
- store = self.hs.get_datastore()
- self.get_success(
- store.create_profile(UserID.from_string(self.user_id).localpart)
- )
-
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
channel = self.make_request("PUT", path, {"displayname": "John Doe"})
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 6a0d9a82be..b0c44af033 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -193,7 +193,7 @@ class RoomTestCase(_ShadowBannedBase):
self.assertEquals(200, channel.code)
# There should be no typing events.
- event_source = self.hs.get_event_sources().sources["typing"]
+ event_source = self.hs.get_event_sources().sources.typing
self.assertEquals(event_source.get_current_key(), 0)
# The other user can join and send typing events.
@@ -210,7 +210,13 @@ class RoomTestCase(_ShadowBannedBase):
# These appear in the room.
self.assertEquals(event_source.get_current_key(), 1)
events = self.get_success(
- event_source.get_new_events(from_key=0, room_ids=[room_id])
+ event_source.get_new_events(
+ user=UserID.from_string(self.other_user_id),
+ from_key=0,
+ limit=None,
+ room_ids=[room_id],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0ae4029640..38ac9be113 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -15,6 +15,7 @@ import threading
from typing import Dict
from unittest.mock import Mock
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi
@@ -327,3 +328,86 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
+
+ def test_sent_event_end_up_in_room_state(self):
+ """Tests that a state event sent by a module while processing another state event
+ doesn't get dropped from the state of the room. This is to guard against a bug
+ where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
+ """
+ event_type = "org.matrix.test_state"
+
+ # This content will be updated later on, and since we actually use a reference on
+ # the dict it does the right thing. It's a bit hacky but a handy way of making
+ # sure the state actually gets updated.
+ event_content = {"i": -1}
+
+ api = self.hs.get_module_api()
+
+ # Define a callback that sends a custom event on power levels update.
+ async def test_fn(event: EventBase, state_events):
+ if event.is_state and event.type == EventTypes.PowerLevels:
+ await api.create_and_send_event_into_room(
+ {
+ "room_id": event.room_id,
+ "sender": event.sender,
+ "type": event_type,
+ "content": event_content,
+ "state_key": "",
+ }
+ )
+ return True, None
+
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [test_fn]
+
+ # Sometimes the bug might not happen the first time the event type is added
+ # to the state but might happen when an event updates the state of the room for
+ # that type, so we test updating the state several times.
+ for i in range(5):
+ # Update the content of the custom state event to be sent by the callback.
+ event_content["i"] = i
+
+ # Update the room's power levels with a different value each time so Synapse
+ # doesn't consider an update redundant.
+ self._update_power_levels(event_default=i)
+
+ # Check that the new event made it to the room's state.
+ channel = self.make_request(
+ method="GET",
+ path="/rooms/" + self.room_id + "/state/" + event_type,
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["i"], i)
+
+ def _update_power_levels(self, event_default: int = 0):
+ """Updates the room's power levels.
+
+ Args:
+ event_default: Value to use for 'events_default'.
+ """
+ self.helper.send_state(
+ room_id=self.room_id,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "ban": 50,
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.encryption": 100,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.server_acl": 100,
+ "m.room.tombstone": 100,
+ },
+ "events_default": event_default,
+ "invite": 0,
+ "kick": 50,
+ "redact": 50,
+ "state_default": 50,
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ },
+ tok=self.tok,
+ )
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index b54b004733..ee0abd5295 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -41,7 +41,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
federation_client=Mock(),
)
- self.event_source = hs.get_event_sources().sources["typing"]
+ self.event_source = hs.get_event_sources().sources.typing
hs.get_federation_handler = Mock()
@@ -76,7 +76,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ self.event_source.get_new_events(
+ user=UserID.from_string(self.user_id),
+ from_key=0,
+ limit=None,
+ room_ids=[self.room_id],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 954ad1a1fd..3075d3f288 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -138,6 +138,7 @@ class RestHelper:
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
+ expect_errcode: str = None,
) -> None:
"""
Send a membership state event into a room.
@@ -150,6 +151,7 @@ class RestHelper:
extra_data: Extra information to include in the content of the event
tok: The user access token to use
expect_code: The expected HTTP response code
+ expect_errcode: The expected Matrix error code
"""
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -177,6 +179,15 @@ class RestHelper:
channel.result["body"],
)
+ if expect_errcode:
+ assert (
+ str(channel.json_body["errcode"]) == expect_errcode
+ ), "Expected: %r, got: %r, resp: %r" % (
+ expect_errcode,
+ channel.json_body["errcode"],
+ channel.result["body"],
+ )
+
self.auth_user_id = temp_id
def send(
@@ -372,7 +383,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
self.hs.get_reactor(),
- FakeSite(resource),
+ FakeSite(resource, self.hs.get_reactor()),
"POST",
path,
content=image_data,
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index a75c0ea3f0..4672a68596 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body.
"""
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel)
+ req = SynapseRequest(channel, self.site)
req.content = BytesIO(b"")
req.requestReceived(
b"GET",
@@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
)
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel)
+ req = SynapseRequest(channel, self.site)
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 9ea1c2bf25..4ae00755c9 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -53,7 +53,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.primary_base_path = os.path.join(self.test_dir, "primary")
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
- hs.config.media_store_path = self.primary_base_path
+ hs.config.media.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
@@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
self.media_id,
shorthand=False,
@@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale"
channel = make_request(
self.reactor,
- FakeSite(self.thumbnail_resource),
+ FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
@@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
- FakeSite(self.thumbnail_resource),
+ FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
@@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
- FakeSite(self.thumbnail_resource),
+ FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 9f6fbfe6de..4d09b5d07e 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -21,9 +21,11 @@ from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
from tests.server import FakeTransport
+from tests.test_utils import SMALL_PNG
try:
import lxml
@@ -576,13 +578,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}
oembed_content = json.dumps(result).encode("utf-8")
- end_content = (
- b"<html><head>"
- b"<title>Some Title</title>"
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
-
channel = self.make_request(
"GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345",
@@ -606,6 +601,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
+ # Ensure a second request is made to the photo URL.
client = self.reactor.tcpClients[1][2].buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
@@ -613,18 +609,24 @@ class URLPreviewTests(unittest.HomeserverTestCase):
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ b"Content-Type: image/png\r\n\r\n"
)
- % (len(end_content),)
- + end_content
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
)
self.pump()
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/matrixdotorg", server.data)
+
self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
- )
+ body = channel.json_body
+ self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
+ self.assertTrue(body["og:image"].startswith("mxc://"))
+ self.assertEqual(body["og:image:height"], 1)
+ self.assertEqual(body["og:image:width"], 1)
+ self.assertEqual(body["og:image:type"], "image/png")
def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
@@ -633,6 +635,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
result = {
"version": "1.0",
"type": "rich",
+ # Note that this provides the author, not the title.
+ "author_name": "Alice",
"html": "<div>Content Preview</div>",
}
end_content = json.dumps(result).encode("utf-8")
@@ -660,9 +664,14 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
+ body = channel.json_body
self.assertEqual(
- channel.json_body,
- {"og:title": None, "og:description": "Content Preview"},
+ body,
+ {
+ "og:url": "http://twitter.com/matrixdotorg/status/12345",
+ "og:title": "Alice",
+ "og:description": "Content Preview",
+ },
)
def test_oembed_format(self):
@@ -705,7 +714,140 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIn(b"format=json", server.data)
self.assertEqual(channel.code, 200)
+ body = channel.json_body
self.assertEqual(
- channel.json_body,
- {"og:title": None, "og:description": "Content Preview"},
+ body,
+ {
+ "og:url": "http://www.hulu.com/watch/12345",
+ "og:description": "Content Preview",
+ },
+ )
+
+ def _download_image(self):
+ """Downloads an image into the URL cache.
+
+ Returns:
+ A (host, media_id) tuple representing the MXC URI of the image.
+ """
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://cdn.twitter.com/matrixdotorg",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: image/png\r\n\r\n"
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ mxc_uri = body["og:image"]
+ host, _port, media_id = parse_and_validate_mxc_uri(mxc_uri)
+ self.assertIsNone(_port)
+ return host, media_id
+
+ def test_storage_providers_exclude_files(self):
+ """Test that files are not stored in or fetched from storage providers."""
+ host, media_id = self._download_image()
+
+ rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id)
+ media_store_path = os.path.join(self.media_store_path, rel_file_path)
+ storage_provider_path = os.path.join(self.storage_path, rel_file_path)
+
+ # Check storage
+ self.assertTrue(os.path.isfile(media_store_path))
+ self.assertFalse(
+ os.path.isfile(storage_provider_path),
+ "URL cache file was unexpectedly stored in a storage provider",
+ )
+
+ # Check fetching
+ channel = self.make_request(
+ "GET",
+ f"download/{host}/{media_id}",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # Move cached file into the storage provider
+ os.makedirs(os.path.dirname(storage_provider_path), exist_ok=True)
+ os.rename(media_store_path, storage_provider_path)
+
+ channel = self.make_request(
+ "GET",
+ f"download/{host}/{media_id}",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(
+ channel.code,
+ 404,
+ "URL cache file was unexpectedly retrieved from a storage provider",
+ )
+
+ def test_storage_providers_exclude_thumbnails(self):
+ """Test that thumbnails are not stored in or fetched from storage providers."""
+ host, media_id = self._download_image()
+
+ rel_thumbnail_path = (
+ self.preview_url.filepaths.url_cache_thumbnail_directory_rel(media_id)
+ )
+ media_store_thumbnail_path = os.path.join(
+ self.media_store_path, rel_thumbnail_path
+ )
+ storage_provider_thumbnail_path = os.path.join(
+ self.storage_path, rel_thumbnail_path
+ )
+
+ # Check storage
+ self.assertTrue(os.path.isdir(media_store_thumbnail_path))
+ self.assertFalse(
+ os.path.isdir(storage_provider_thumbnail_path),
+ "URL cache thumbnails were unexpectedly stored in a storage provider",
+ )
+
+ # Check fetching
+ channel = self.make_request(
+ "GET",
+ f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # Remove the original, otherwise thumbnails will regenerate
+ rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id)
+ media_store_path = os.path.join(self.media_store_path, rel_file_path)
+ os.remove(media_store_path)
+
+ # Move cached thumbnails into the storage provider
+ os.makedirs(os.path.dirname(storage_provider_thumbnail_path), exist_ok=True)
+ os.rename(media_store_thumbnail_path, storage_provider_thumbnail_path)
+
+ channel = self.make_request(
+ "GET",
+ f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(
+ channel.code,
+ 404,
+ "URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
diff --git a/tests/server.py b/tests/server.py
index b861c7b866..88dfa8058e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -19,6 +19,7 @@ from twisted.internet.interfaces import (
IPullProducer,
IPushProducer,
IReactorPluggableNameResolver,
+ IReactorTime,
IResolverSimple,
ITransport,
)
@@ -181,13 +182,14 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
- def __init__(self, resource: IResource):
+ def __init__(self, resource: IResource, reactor: IReactorTime):
"""
Args:
resource: the resource to be used for rendering all requests
"""
self._resource = resource
+ self.reactor = reactor
def getResourceFor(self, request):
return self._resource
@@ -268,7 +270,7 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip)
- req = request(channel)
+ req = request(channel, site)
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 8701b5f7e3..7f25200a5d 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -326,7 +326,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
for event in events:
if (
event["type"] == EventTypes.Message
- and event["sender"] == self.hs.config.server_notices_mxid
+ and event["sender"] == self.hs.config.servernotices.server_notices_mxid
):
notice_in_room = True
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 666bffe257..cf9748f218 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -41,9 +41,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.appservice.app_service_config_files = self.as_yaml_files
hs.config.caches.event_cache_size = 1
- hs.config.password_providers = []
self.as_token = "token1"
self.as_url = "some_url"
@@ -108,9 +107,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- hs.config.app_service_config_files = self.as_yaml_files
+ hs.config.appservice.app_service_config_files = self.as_yaml_files
hs.config.caches.event_cache_size = 1
- hs.config.password_providers = []
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -496,9 +494,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- hs.config.app_service_config_files = [f1, f2]
+ hs.config.appservice.app_service_config_files = [f1, f2]
hs.config.caches.event_cache_size = 1
- hs.config.password_providers = []
database = hs.get_datastores().databases[0]
ApplicationServiceStore(
@@ -514,9 +511,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- hs.config.app_service_config_files = [f1, f2]
+ hs.config.appservice.app_service_config_files = [f1, f2]
hs.config.caches.event_cache_size = 1
- hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
@@ -540,9 +536,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
- hs.config.app_service_config_files = [f1, f2]
+ hs.config.appservice.app_service_config_files = [f1, f2]
hs.config.caches.event_cache_size = 1
- hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
database = hs.get_datastores().databases[0]
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index da98733ce8..7cc5e621ba 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -258,7 +258,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler()
- homeserver.config.user_consent_version = self.CONSENT_VERSION
+ homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
def test_send_dummy_event(self):
self._create_extremity_rich_graph()
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
new file mode 100644
index 0000000000..8971ecccbd
--- /dev/null
+++ b/tests/storage/test_room_search.py
@@ -0,0 +1,74 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse.rest.admin
+from synapse.rest.client import login, room
+from synapse.storage.engines import PostgresEngine
+
+from tests.unittest import HomeserverTestCase
+
+
+class NullByteInsertionTest(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def test_null_byte(self):
+ """
+ Postgres/SQLite don't like null bytes going into the search tables. Internally
+ we replace those with a space.
+
+ Ensure this doesn't break anything.
+ """
+
+ # Register a user and create a room, create some messages
+ self.register_user("alice", "password")
+ access_token = self.login("alice", "password")
+ room_id = self.helper.create_room_as("alice", tok=access_token)
+
+ # Send messages and ensure they don't cause an internal server
+ # error
+ for body in ["hi\u0000bob", "another message", "hi alice"]:
+ response = self.helper.send(room_id, body, tok=access_token)
+ self.assertIn("event_id", response)
+
+ # Check that search works for the message where the null byte was replaced
+ store = self.hs.get_datastore()
+ result = self.get_success(
+ store.search_msgs([room_id], "hi bob", ["content.body"])
+ )
+ self.assertEquals(result.get("count"), 1)
+ if isinstance(store.database_engine, PostgresEngine):
+ self.assertIn("hi", result.get("highlights"))
+ self.assertIn("bob", result.get("highlights"))
+
+ # Check that search works for an unrelated message
+ result = self.get_success(
+ store.search_msgs([room_id], "another", ["content.body"])
+ )
+ self.assertEquals(result.get("count"), 1)
+ if isinstance(store.database_engine, PostgresEngine):
+ self.assertIn("another", result.get("highlights"))
+
+ # Check that search works for a search term that overlaps with the message
+ # containing a null byte and an unrelated message.
+ result = self.get_success(store.search_msgs([room_id], "hi", ["content.body"]))
+ self.assertEquals(result.get("count"), 2)
+ result = self.get_success(
+ store.search_msgs([room_id], "hi alice", ["content.body"])
+ )
+ if isinstance(store.database_engine, PostgresEngine):
+ self.assertIn("alice", result.get("highlights"))
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c51e018da1..24fc77d7a7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
event,
context,
state=None,
- claimed_auth_event_map=None,
backfilled=False,
):
return context
diff --git a/tests/test_server.py b/tests/test_server.py
index 407e172e41..f2ffbc895b 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase):
)
make_request(
- self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
+ self.reactor,
+ FakeSite(res, self.reactor),
+ b"GET",
+ b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
)
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
@@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
+ )
self.assertEqual(channel.result["code"], b"500")
@@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase):
def _callback(request, **kwargs):
d = Deferred()
d.addCallback(_throw)
- self.reactor.callLater(1, d.callback, True)
+ self.reactor.callLater(0.5, d.callback, True)
return make_deferred_yieldable(d)
res = JsonResource(self.homeserver)
@@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
+ )
self.assertEqual(channel.result["code"], b"500")
@@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
+ )
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
- channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
+ )
self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase):
)
# The path was registered as GET, but this is a HEAD request.
- channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
+ )
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
@@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
+ )
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
@@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
+ )
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
@@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
+ )
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
@@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
- channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
+ channel = make_request(
+ self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
+ )
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
diff --git a/tests/test_state.py b/tests/test_state.py
index e5488df1ac..76e0e8ca7f 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -106,7 +106,7 @@ class StateGroupStore:
}
async def get_state_group_delta(self, name):
- return (None, None)
+ return None, None
def register_events(self, events):
for e in events:
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 34aaffe859..89d8656634 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -95,4 +95,4 @@ def build_rc_config(settings: Optional[dict] = None):
config_dict.update(settings or {})
config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
- return config.rc_federation
+ return config.ratelimiting.rc_federation
diff --git a/tests/utils.py b/tests/utils.py
index f3458ca88d..cf8ba5c5db 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -434,7 +434,7 @@ class MockHttpResource:
)
return code, response
except CodeMessageException as e:
- return (e.code, cs_error(e.msg, code=e.errcode))
+ return e.code, cs_error(e.msg, code=e.errcode)
raise KeyError("No event can handle %s" % path)
|