diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index 5527e278db..d66aeb00eb 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -1,6 +1,6 @@
import synapse
from synapse.app.phone_stats_home import start_phone_stats_home
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests import unittest
from tests.unittest import HomeserverTestCase
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
index 84ae3b88ae..baa5313fb3 100644
--- a/tests/config/test_base.py
+++ b/tests/config/test_base.py
@@ -30,7 +30,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# contain template files
with tempfile.TemporaryDirectory() as tmp_dir:
# Attempt to load an HTML template from our custom template directory
- template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
+ template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0]
# If no errors, we should've gotten the default template instead
@@ -60,7 +60,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
# Attempt to load the template from our custom template directory
template = (
- self.hs.config.read_templates([template_filename], tmp_dir)
+ self.hs.config.read_templates([template_filename], (tmp_dir,))
)[0]
# Render the template
@@ -74,8 +74,66 @@ class BaseConfigTestCase(unittest.HomeserverTestCase):
"Template file did not contain our test string",
)
+ def test_multiple_custom_template_directories(self):
+ """Tests that directories are searched in the right order if multiple custom
+ template directories are provided.
+ """
+ # Create two temporary directories on the filesystem.
+ tempdirs = [
+ tempfile.TemporaryDirectory(),
+ tempfile.TemporaryDirectory(),
+ ]
+
+ # Create one template in each directory, whose content is the index of the
+ # directory in the list.
+ template_filename = "my_template.html.j2"
+ for i in range(len(tempdirs)):
+ tempdir = tempdirs[i]
+ template_path = os.path.join(tempdir.name, template_filename)
+
+ with open(template_path, "w") as fp:
+ fp.write(str(i))
+ fp.flush()
+
+ # Retrieve the template.
+ template = (
+ self.hs.config.read_templates(
+ [template_filename],
+ (td.name for td in tempdirs),
+ )
+ )[0]
+
+ # Test that we got the template we dropped in the first directory in the list.
+ self.assertEqual(template.render(), "0")
+
+ # Add another template, this one only in the second directory in the list, so we
+ # can test that the second directory is still searched into when no matching file
+ # could be found in the first one.
+ other_template_name = "my_other_template.html.j2"
+ other_template_path = os.path.join(tempdirs[1].name, other_template_name)
+
+ with open(other_template_path, "w") as fp:
+ fp.write("hello world")
+ fp.flush()
+
+ # Retrieve the template.
+ template = (
+ self.hs.config.read_templates(
+ [other_template_name],
+ (td.name for td in tempdirs),
+ )
+ )[0]
+
+ # Test that the file has the expected content.
+ self.assertEqual(template.render(), "hello world")
+
+ # Cleanup the temporary directories manually since we're not using a context
+ # manager.
+ for td in tempdirs:
+ td.cleanup()
+
def test_loading_template_from_nonexistent_custom_directory(self):
with self.assertRaises(ConfigError):
self.hs.config.read_templates(
- ["some_filename.html"], "a_nonexistent_directory"
+ ["some_filename.html"], ("a_nonexistent_directory",)
)
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index 6f2b9e997d..b6f21294ba 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -35,7 +35,7 @@ class ServerConfigTestCase(unittest.TestCase):
def test_unsecure_listener_no_listeners_open_private_ports_false(self):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "che.org", "/data_dir_path", False, None
+ "che.org", "/data_dir_path", False, None, config_dir_path="CONFDIR"
)
)
@@ -55,7 +55,7 @@ class ServerConfigTestCase(unittest.TestCase):
def test_unsecure_listener_no_listeners_open_private_ports_true(self):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "che.org", "/data_dir_path", True, None
+ "che.org", "/data_dir_path", True, None, config_dir_path="CONFDIR"
)
)
@@ -89,7 +89,7 @@ class ServerConfigTestCase(unittest.TestCase):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "this.one.listens", "/data_dir_path", True, listeners
+ "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR"
)
)
@@ -123,7 +123,7 @@ class ServerConfigTestCase(unittest.TestCase):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "this.one.listens", "/data_dir_path", True, listeners
+ "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR"
)
)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 3f41e99950..3b3866bff8 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -17,12 +17,12 @@ from unittest.mock import Mock
import attr
from synapse.api.constants import EduTypes
-from synapse.events.presence_router import PresenceRouter
+from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
from synapse.rest import admin
-from synapse.rest.client.v1 import login, presence, room
+from synapse.rest.client import login, presence, room
from synapse.types import JsonDict, StreamToken, create_requester
from tests.handlers.test_sync import generate_sync_config
@@ -34,7 +34,7 @@ class PresenceRouterTestConfig:
users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
-class PresenceRouterTestModule:
+class LegacyPresenceRouterTestModule:
def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
self._config = config
self._module_api = module_api
@@ -77,6 +77,53 @@ class PresenceRouterTestModule:
return config
+class PresenceRouterTestModule:
+ def __init__(self, config: PresenceRouterTestConfig, api: ModuleApi):
+ self._config = config
+ self._module_api = api
+ api.register_presence_router_callbacks(
+ get_users_for_states=self.get_users_for_states,
+ get_interested_users=self.get_interested_users,
+ )
+
+ async def get_users_for_states(
+ self, state_updates: Iterable[UserPresenceState]
+ ) -> Dict[str, Set[UserPresenceState]]:
+ users_to_state = {
+ user_id: set(state_updates)
+ for user_id in self._config.users_who_should_receive_all_presence
+ }
+ return users_to_state
+
+ async def get_interested_users(
+ self, user_id: str
+ ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ if user_id in self._config.users_who_should_receive_all_presence:
+ return PresenceRouter.ALL_USERS
+
+ return set()
+
+ @staticmethod
+ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+ """Parse a configuration dictionary from the homeserver config, do
+ some validation and return a typed PresenceRouterConfig.
+
+ Args:
+ config_dict: The configuration dictionary.
+
+ Returns:
+ A validated config object.
+ """
+ # Initialise a typed config object
+ config = PresenceRouterTestConfig()
+
+ config.users_who_should_receive_all_presence = config_dict.get(
+ "users_who_should_receive_all_presence"
+ )
+
+ return config
+
+
class PresenceRouterTestCase(FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -86,9 +133,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- return self.setup_test_homeserver(
+ hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
+ # Load the modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+
+ load_legacy_presence_router(hs)
+
+ return hs
def prepare(self, reactor, clock, homeserver):
self.sync_handler = self.hs.get_sync_handler()
@@ -98,7 +153,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
{
"presence": {
"presence_router": {
- "module": __name__ + ".PresenceRouterTestModule",
+ "module": __name__ + ".LegacyPresenceRouterTestModule",
"config": {
"users_who_should_receive_all_presence": [
"@presence_gobbler:test",
@@ -109,7 +164,28 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
"send_federation": True,
}
)
+ def test_receiving_all_presence_legacy(self):
+ self.receiving_all_presence_test_body()
+
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler:test",
+ ]
+ },
+ },
+ ],
+ "send_federation": True,
+ }
+ )
def test_receiving_all_presence(self):
+ self.receiving_all_presence_test_body()
+
+ def receiving_all_presence_test_body(self):
"""Test that a user that does not share a room with another other can receive
presence for them, due to presence routing.
"""
@@ -203,7 +279,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
{
"presence": {
"presence_router": {
- "module": __name__ + ".PresenceRouterTestModule",
+ "module": __name__ + ".LegacyPresenceRouterTestModule",
"config": {
"users_who_should_receive_all_presence": [
"@presence_gobbler1:test",
@@ -216,7 +292,30 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
"send_federation": True,
}
)
+ def test_send_local_online_presence_to_with_module_legacy(self):
+ self.send_local_online_presence_to_with_module_test_body()
+
+ @override_config(
+ {
+ "modules": [
+ {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler1:test",
+ "@presence_gobbler2:test",
+ "@far_away_person:island",
+ ]
+ },
+ },
+ ],
+ "send_federation": True,
+ }
+ )
def test_send_local_online_presence_to_with_module(self):
+ self.send_local_online_presence_to_with_module_test_body()
+
+ def send_local_online_presence_to_with_module_test_body(self):
"""Tests that send_local_presence_to_users sends local online presence to a set
of specified local and remote users, with a custom PresenceRouter module enabled.
"""
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 48e98aac79..ca27388ae8 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -14,7 +14,7 @@
from synapse.events.snapshot import EventContext
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests import unittest
from tests.test_utils.event_injection import create_event
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index e2a5fc018c..5446fda5e7 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -322,7 +322,7 @@ class PruneEventTestCase(unittest.TestCase):
},
)
- # After MSC3083, alias events have no special behavior.
+ # After MSC3083, the allow key is protected from redaction.
self.run_test(
{
"type": "m.room.join_rules",
@@ -341,7 +341,51 @@ class PruneEventTestCase(unittest.TestCase):
"signatures": {},
"unsigned": {},
},
- room_version=RoomVersions.MSC3083,
+ room_version=RoomVersions.V8,
+ )
+
+ def test_member(self):
+ """Member events have changed behavior starting with MSC3375."""
+ self.run_test(
+ {
+ "type": "m.room.member",
+ "event_id": "$test:domain",
+ "content": {
+ "membership": "join",
+ "join_authorised_via_users_server": "@user:domain",
+ "other_key": "stripped",
+ },
+ },
+ {
+ "type": "m.room.member",
+ "event_id": "$test:domain",
+ "content": {"membership": "join"},
+ "signatures": {},
+ "unsigned": {},
+ },
+ )
+
+ # After MSC3375, the join_authorised_via_users_server key is protected
+ # from redaction.
+ self.run_test(
+ {
+ "type": "m.room.member",
+ "content": {
+ "membership": "join",
+ "join_authorised_via_users_server": "@user:domain",
+ "other_key": "stripped",
+ },
+ },
+ {
+ "type": "m.room.member",
+ "content": {
+ "membership": "join",
+ "join_authorised_via_users_server": "@user:domain",
+ },
+ "signatures": {},
+ "unsigned": {},
+ },
+ room_version=RoomVersions.V9,
)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 1a809b2a6a..7b486aba4a 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -16,7 +16,7 @@ from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.types import UserID
from tests import unittest
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 802c5ad299..f0aa8ed9db 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -6,7 +6,7 @@ from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b00dd143d6..65b18fbd7a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.rest import admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client import login
from synapse.types import JsonDict, ReadReceipt
from tests.test_utils import make_awaitable
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 1737891564..0b60cc4261 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -19,7 +19,7 @@ from parameterized import parameterized
from synapse.events import make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests import unittest
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index aab44bce4a..663960ff53 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import builder
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import RoomAlias
@@ -208,7 +208,7 @@ class FederationKnockingTestCase(
async def _check_event_auth(origin, event, context, *args, **kwargs):
return context
- homeserver.get_federation_handler()._check_event_auth = _check_event_auth
+ homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth
return super().prepare(reactor, clock, homeserver)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 18a734daf4..59de1142b1 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -15,12 +15,10 @@
from collections import Counter
from unittest.mock import Mock
-import synapse.api.errors
-import synapse.handlers.admin
import synapse.rest.admin
import synapse.storage
from synapse.api.constants import EventTypes
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests import unittest
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 024c5e963c..43998020b2 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.assertEquals(result.room_id, room_id)
self.assertEquals(result.servers, servers)
- def _mkservice(self, is_interested):
+ def test_get_3pe_protocols_no_appservices(self):
+ self.mock_store.get_app_services.return_value = []
+ response = self.successResultOf(
+ defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
+ )
+ self.mock_as_api.get_3pe_protocol.assert_not_called()
+ self.assertEquals(response, {})
+
+ def test_get_3pe_protocols_no_protocols(self):
+ service = self._mkservice(False, [])
+ self.mock_store.get_app_services.return_value = [service]
+ response = self.successResultOf(
+ defer.ensureDeferred(self.handler.get_3pe_protocols())
+ )
+ self.mock_as_api.get_3pe_protocol.assert_not_called()
+ self.assertEquals(response, {})
+
+ def test_get_3pe_protocols_protocol_no_response(self):
+ service = self._mkservice(False, ["my-protocol"])
+ self.mock_store.get_app_services.return_value = [service]
+ self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+ response = self.successResultOf(
+ defer.ensureDeferred(self.handler.get_3pe_protocols())
+ )
+ self.mock_as_api.get_3pe_protocol.assert_called_once_with(
+ service, "my-protocol"
+ )
+ self.assertEquals(response, {})
+
+ def test_get_3pe_protocols_select_one_protocol(self):
+ service = self._mkservice(False, ["my-protocol"])
+ self.mock_store.get_app_services.return_value = [service]
+ self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
+ {"x-protocol-data": 42, "instances": []}
+ )
+ response = self.successResultOf(
+ defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
+ )
+ self.mock_as_api.get_3pe_protocol.assert_called_once_with(
+ service, "my-protocol"
+ )
+ self.assertEquals(
+ response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
+ )
+
+ def test_get_3pe_protocols_one_protocol(self):
+ service = self._mkservice(False, ["my-protocol"])
+ self.mock_store.get_app_services.return_value = [service]
+ self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
+ {"x-protocol-data": 42, "instances": []}
+ )
+ response = self.successResultOf(
+ defer.ensureDeferred(self.handler.get_3pe_protocols())
+ )
+ self.mock_as_api.get_3pe_protocol.assert_called_once_with(
+ service, "my-protocol"
+ )
+ self.assertEquals(
+ response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
+ )
+
+ def test_get_3pe_protocols_multiple_protocol(self):
+ service_one = self._mkservice(False, ["my-protocol"])
+ service_two = self._mkservice(False, ["other-protocol"])
+ self.mock_store.get_app_services.return_value = [service_one, service_two]
+ self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
+ {"x-protocol-data": 42, "instances": []}
+ )
+ response = self.successResultOf(
+ defer.ensureDeferred(self.handler.get_3pe_protocols())
+ )
+ self.mock_as_api.get_3pe_protocol.assert_called()
+ self.assertEquals(
+ response,
+ {
+ "my-protocol": {"x-protocol-data": 42, "instances": []},
+ "other-protocol": {"x-protocol-data": 42, "instances": []},
+ },
+ )
+
+ def test_get_3pe_protocols_multiple_info(self):
+ service_one = self._mkservice(False, ["my-protocol"])
+ service_two = self._mkservice(False, ["my-protocol"])
+
+ async def get_3pe_protocol(service, unusedProtocol):
+ if service == service_one:
+ return {
+ "x-protocol-data": 42,
+ "instances": [{"desc": "Alice's service"}],
+ }
+ if service == service_two:
+ return {
+ "x-protocol-data": 36,
+ "x-not-used": 45,
+ "instances": [{"desc": "Bob's service"}],
+ }
+ raise Exception("Unexpected service")
+
+ self.mock_store.get_app_services.return_value = [service_one, service_two]
+ self.mock_as_api.get_3pe_protocol = get_3pe_protocol
+ response = self.successResultOf(
+ defer.ensureDeferred(self.handler.get_3pe_protocols())
+ )
+ # It's expected that the second service's data doesn't appear in the response
+ self.assertEquals(
+ response,
+ {
+ "my-protocol": {
+ "x-protocol-data": 42,
+ "instances": [
+ {
+ "desc": "Alice's service",
+ },
+ {"desc": "Bob's service"},
+ ],
+ },
+ },
+ )
+
+ def _mkservice(self, is_interested, protocols=None):
service = Mock()
service.is_interested.return_value = make_awaitable(is_interested)
service.token = "mock_service_token"
service.url = "mock_service_url"
+ service.protocols = protocols
return service
def _mkservice_alias(self, is_interested_in_alias):
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 7a8041ab44..a0a48b564e 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -19,7 +19,7 @@ import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
from synapse.config.room_directory import RoomDirectoryConfig
-from synapse.rest.client.v1 import directory, login, room
+from synapse.rest.client import directory, login, room
from synapse.types import RoomAlias, create_requester
from tests import unittest
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 4140fcefc2..6c67a16de9 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -22,7 +22,7 @@ from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.util.stringutils import random_string
from tests import unittest
@@ -130,7 +130,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -182,7 +184,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -311,7 +315,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
d = run_in_background(
- self.handler.on_receive_pdu, OTHER_SERVER, message_event
+ self.hs.get_federation_event_handler().on_receive_pdu,
+ OTHER_SERVER,
+ message_event,
)
self.get_success(d)
@@ -382,7 +388,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
join_event.signatures[other_server] = {"x": "y"}
with LoggingContext("send_join"):
d = run_in_background(
- self.handler.on_send_membership_event, other_server, join_event
+ self.hs.get_federation_event_handler().on_send_membership_event,
+ other_server,
+ join_event,
)
self.get_success(d)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index a8a9fc5b62..8a8d369fac 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.types import create_requester
from synapse.util.stringutils import random_string
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 32651db096..38e6d9f536 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,8 +20,7 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse
-from synapse.rest.client.v1 import login
-from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client import devices, login
from synapse.types import JsonDict
from tests import unittest
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 18e92e90d7..671dc7d083 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Optional
from unittest.mock import Mock, call
from signedjson.key import generate_signing_key
@@ -33,7 +33,7 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest import admin
-from synapse.rest.client.v1 import room
+from synapse.rest.client import room
from synapse.types import UserID, get_domain_from_id
from tests import unittest
@@ -339,8 +339,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
class PresenceTimeoutTestCase(unittest.TestCase):
+ """Tests different timers and that the timer does not change `status_msg` of user."""
+
def test_idle_timer(self):
user_id = "@foo:bar"
+ status_msg = "I'm here!"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -348,12 +351,14 @@ class PresenceTimeoutTestCase(unittest.TestCase):
state=PresenceState.ONLINE,
last_active_ts=now - IDLE_TIMER - 1,
last_user_sync_ts=now,
+ status_msg=status_msg,
)
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
+ self.assertEquals(new_state.status_msg, status_msg)
def test_busy_no_idle(self):
"""
@@ -361,6 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
presence state into unavailable.
"""
user_id = "@foo:bar"
+ status_msg = "I'm here!"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -368,15 +374,18 @@ class PresenceTimeoutTestCase(unittest.TestCase):
state=PresenceState.BUSY,
last_active_ts=now - IDLE_TIMER - 1,
last_user_sync_ts=now,
+ status_msg=status_msg,
)
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.BUSY)
+ self.assertEquals(new_state.status_msg, status_msg)
def test_sync_timeout(self):
user_id = "@foo:bar"
+ status_msg = "I'm here!"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -384,15 +393,18 @@ class PresenceTimeoutTestCase(unittest.TestCase):
state=PresenceState.ONLINE,
last_active_ts=0,
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
+ status_msg=status_msg,
)
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.OFFLINE)
+ self.assertEquals(new_state.status_msg, status_msg)
def test_sync_online(self):
user_id = "@foo:bar"
+ status_msg = "I'm here!"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -400,6 +412,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
state=PresenceState.ONLINE,
last_active_ts=now - SYNC_ONLINE_TIMEOUT - 1,
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
+ status_msg=status_msg,
)
new_state = handle_timeout(
@@ -408,9 +421,11 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.ONLINE)
+ self.assertEquals(new_state.status_msg, status_msg)
def test_federation_ping(self):
user_id = "@foo:bar"
+ status_msg = "I'm here!"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -419,12 +434,13 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_active_ts=now,
last_user_sync_ts=now,
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
+ status_msg=status_msg,
)
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
- self.assertEquals(new_state, new_state)
+ self.assertEquals(state, new_state)
def test_no_timeout(self):
user_id = "@foo:bar"
@@ -444,6 +460,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
def test_federation_timeout(self):
user_id = "@foo:bar"
+ status_msg = "I'm here!"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -452,6 +469,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_active_ts=now,
last_user_sync_ts=now,
last_federation_update_ts=now - FEDERATION_TIMEOUT - 1,
+ status_msg=status_msg,
)
new_state = handle_timeout(
@@ -460,9 +478,11 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.OFFLINE)
+ self.assertEquals(new_state.status_msg, status_msg)
def test_last_active(self):
user_id = "@foo:bar"
+ status_msg = "I'm here!"
now = 5000000
state = UserPresenceState.default(user_id)
@@ -471,6 +491,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1,
last_user_sync_ts=now,
last_federation_update_ts=now,
+ status_msg=status_msg,
)
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
@@ -516,6 +537,144 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(state.state, PresenceState.OFFLINE)
+ def test_user_goes_offline_by_timeout_status_msg_remain(self):
+ """Test that if a user doesn't update the records for a while
+ users presence goes `OFFLINE` because of timeout and `status_msg` remains.
+ """
+ user_id = "@test:server"
+ status_msg = "I'm here!"
+
+ # Mark user as online
+ self._set_presencestate_with_status_msg(
+ user_id, PresenceState.ONLINE, status_msg
+ )
+
+ # Check that if we wait a while without telling the handler the user has
+ # stopped syncing that their presence state doesn't get timed out.
+ self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2)
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.ONLINE)
+ self.assertEqual(state.status_msg, status_msg)
+
+ # Check that if the timeout fires, then the syncing user gets timed out
+ self.reactor.advance(SYNC_ONLINE_TIMEOUT)
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ # status_msg should remain even after going offline
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+ self.assertEqual(state.status_msg, status_msg)
+
+ def test_user_goes_offline_manually_with_no_status_msg(self):
+ """Test that if a user change presence manually to `OFFLINE`
+ and no status is set, that `status_msg` is `None`.
+ """
+ user_id = "@test:server"
+ status_msg = "I'm here!"
+
+ # Mark user as online
+ self._set_presencestate_with_status_msg(
+ user_id, PresenceState.ONLINE, status_msg
+ )
+
+ # Mark user as offline
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string(user_id), {"presence": PresenceState.OFFLINE}
+ )
+ )
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+ self.assertEqual(state.status_msg, None)
+
+ def test_user_goes_offline_manually_with_status_msg(self):
+ """Test that if a user change presence manually to `OFFLINE`
+ and a status is set, that `status_msg` appears.
+ """
+ user_id = "@test:server"
+ status_msg = "I'm here!"
+
+ # Mark user as online
+ self._set_presencestate_with_status_msg(
+ user_id, PresenceState.ONLINE, status_msg
+ )
+
+ # Mark user as offline
+ self._set_presencestate_with_status_msg(
+ user_id, PresenceState.OFFLINE, "And now here."
+ )
+
+ def test_user_reset_online_with_no_status(self):
+ """Test that if a user set again the presence manually
+ and no status is set, that `status_msg` is `None`.
+ """
+ user_id = "@test:server"
+ status_msg = "I'm here!"
+
+ # Mark user as online
+ self._set_presencestate_with_status_msg(
+ user_id, PresenceState.ONLINE, status_msg
+ )
+
+ # Mark user as online again
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string(user_id), {"presence": PresenceState.ONLINE}
+ )
+ )
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ # status_msg should remain even after going offline
+ self.assertEqual(state.state, PresenceState.ONLINE)
+ self.assertEqual(state.status_msg, None)
+
+ def test_set_presence_with_status_msg_none(self):
+ """Test that if a user set again the presence manually
+ and status is `None`, that `status_msg` is `None`.
+ """
+ user_id = "@test:server"
+ status_msg = "I'm here!"
+
+ # Mark user as online
+ self._set_presencestate_with_status_msg(
+ user_id, PresenceState.ONLINE, status_msg
+ )
+
+ # Mark user as online and `status_msg = None`
+ self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
+
+ def _set_presencestate_with_status_msg(
+ self, user_id: str, state: PresenceState, status_msg: Optional[str]
+ ):
+ """Set a PresenceState and status_msg and check the result.
+
+ Args:
+ user_id: User for that the status is to be set.
+ PresenceState: The new PresenceState.
+ status_msg: Status message that is to be set.
+ """
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string(user_id),
+ {"presence": state, "status_msg": status_msg},
+ )
+ )
+
+ new_state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(new_state.state, state)
+ self.assertEqual(new_state.status_msg, status_msg)
+
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
@@ -726,7 +885,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
- self.federation_handler = hs.get_federation_handler()
+ self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
@@ -867,7 +1026,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
)
- self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
+ self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 93a9a084b2..732a12c9bd 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -286,6 +286,29 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
+ def test_handles_string_data(self):
+ """
+ Tests that an invalid shape for read-receipts is handled.
+ Context: https://github.com/matrix-org/synapse/issues/10603
+ """
+
+ self._test_filters_hidden(
+ [
+ {
+ "content": {
+ "$14356419edgd14394fHBLK:matrix.org": {
+ "m.read": {
+ "@rikj:jki.re": "string",
+ }
+ },
+ },
+ "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+ "type": "m.receipt",
+ },
+ ],
+ [],
+ )
+
def _test_filters_hidden(
self, events: List[JsonDict], expected_output: List[JsonDict]
):
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
new file mode 100644
index 0000000000..fcde5dab72
--- /dev/null
+++ b/tests/handlers/test_room.py
@@ -0,0 +1,108 @@
+import synapse
+from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms
+from synapse.rest.client import login, room
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ ]
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "all"})
+ def test_encrypted_by_default_config_option_all(self):
+ """Tests that invite-only and non-invite-only rooms have encryption enabled by
+ default when the config option encryption_enabled_by_default_for_room_type is "all".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
+ def test_encrypted_by_default_config_option_invite(self):
+ """Tests that only new, invite-only rooms have encryption enabled by default when
+ the config option encryption_enabled_by_default_for_room_type is "invite".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "off"})
+ def test_encrypted_by_default_config_option_off(self):
+ """Tests that neither new invite-only nor non-invite-only rooms have encryption
+ enabled by default when the config option
+ encryption_enabled_by_default_for_room_type is "off".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
new file mode 100644
index 0000000000..d3d0bf1ac5
--- /dev/null
+++ b/tests/handlers/test_room_summary.py
@@ -0,0 +1,992 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Iterable, List, Optional, Tuple
+from unittest import mock
+
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ HistoryVisibility,
+ JoinRules,
+ Membership,
+ RestrictedJoinRuleTypes,
+ RoomTypes,
+)
+from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.api.room_versions import RoomVersions
+from synapse.events import make_event_from_dict
+from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID
+
+from tests import unittest
+
+
+def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
+ result = mock.Mock(name=room_id)
+ result.room_id = room_id
+ result.content = {}
+ result.origin_server_ts = origin_server_ts
+ if order is not None:
+ result.content["order"] = order
+ return result
+
+
+def _order(*events):
+ return sorted(events, key=_child_events_comparison_key)
+
+
+class TestSpaceSummarySort(unittest.TestCase):
+ def test_no_order_last(self):
+ """An event with no ordering is placed behind those with an ordering."""
+ ev1 = _create_event("!abc:test")
+ ev2 = _create_event("!xyz:test", "xyz")
+
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+ def test_order(self):
+ """The ordering should be used."""
+ ev1 = _create_event("!abc:test", "xyz")
+ ev2 = _create_event("!xyz:test", "abc")
+
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+ def test_order_origin_server_ts(self):
+ """Origin server is a tie-breaker for ordering."""
+ ev1 = _create_event("!abc:test", origin_server_ts=10)
+ ev2 = _create_event("!xyz:test", origin_server_ts=30)
+
+ self.assertEqual([ev1, ev2], _order(ev1, ev2))
+
+ def test_order_room_id(self):
+ """Room ID is a final tie-breaker for ordering."""
+ ev1 = _create_event("!abc:test")
+ ev2 = _create_event("!xyz:test")
+
+ self.assertEqual([ev1, ev2], _order(ev1, ev2))
+
+ def test_invalid_ordering_type(self):
+ """Invalid orderings are considered the same as missing."""
+ ev1 = _create_event("!abc:test", 1)
+ ev2 = _create_event("!xyz:test", "xyz")
+
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+ ev1 = _create_event("!abc:test", {})
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+ ev1 = _create_event("!abc:test", [])
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+ ev1 = _create_event("!abc:test", True)
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+ def test_invalid_ordering_value(self):
+ """Invalid orderings are considered the same as missing."""
+ ev1 = _create_event("!abc:test", "foo\n")
+ ev2 = _create_event("!xyz:test", "xyz")
+
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+ ev1 = _create_event("!abc:test", "a" * 51)
+ self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+
+class SpaceSummaryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
+ self.hs = hs
+ self.handler = self.hs.get_room_summary_handler()
+
+ # Create a user.
+ self.user = self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ # Create a space and a child room.
+ self.space = self.helper.create_room_as(
+ self.user,
+ tok=self.token,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ self.room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, self.room, self.token)
+
+ def _add_child(
+ self, space_id: str, room_id: str, token: str, order: Optional[str] = None
+ ) -> None:
+ """Add a child room to a space."""
+ content: JsonDict = {"via": [self.hs.hostname]}
+ if order is not None:
+ content["order"] = order
+ self.helper.send_state(
+ space_id,
+ event_type=EventTypes.SpaceChild,
+ body=content,
+ tok=token,
+ state_key=room_id,
+ )
+
+ def _assert_rooms(
+ self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
+ ) -> None:
+ """
+ Assert that the expected room IDs and events are in the response.
+
+ Args:
+ result: The result from the API call.
+ rooms_and_children: An iterable of tuples where each tuple is:
+ The expected room ID.
+ The expected IDs of any children rooms.
+ """
+ room_ids = []
+ children_ids = []
+ for room_id, children in rooms_and_children:
+ room_ids.append(room_id)
+ if children:
+ children_ids.extend([(room_id, child_id) for child_id in children])
+ self.assertCountEqual(
+ [room.get("room_id") for room in result["rooms"]], room_ids
+ )
+ self.assertCountEqual(
+ [
+ (event.get("room_id"), event.get("state_key"))
+ for event in result["events"]
+ ],
+ children_ids,
+ )
+
+ def _assert_hierarchy(
+ self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
+ ) -> None:
+ """
+ Assert that the expected room IDs are in the response.
+
+ Args:
+ result: The result from the API call.
+ rooms_and_children: An iterable of tuples where each tuple is:
+ The expected room ID.
+ The expected IDs of any children rooms.
+ """
+ result_room_ids = []
+ result_children_ids = []
+ for result_room in result["rooms"]:
+ result_room_ids.append(result_room["room_id"])
+ result_children_ids.append(
+ [
+ (cs["room_id"], cs["state_key"])
+ for cs in result_room.get("children_state")
+ ]
+ )
+
+ room_ids = []
+ children_ids = []
+ for room_id, children in rooms_and_children:
+ room_ids.append(room_id)
+ children_ids.append([(room_id, child_id) for child_id in children])
+
+ # Note that order matters.
+ self.assertEqual(result_room_ids, room_ids)
+ self.assertEqual(result_children_ids, children_ids)
+
+ def _poke_fed_invite(self, room_id: str, from_user: str) -> None:
+ """
+ Creates a invite (as if received over federation) for the room from the
+ given hostname.
+
+ Args:
+ room_id: The room ID to issue an invite for.
+ fed_hostname: The user to invite from.
+ """
+ # Poke an invite over federation into the database.
+ fed_handler = self.hs.get_federation_handler()
+ fed_hostname = UserID.from_string(from_user).domain
+ event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "event_id": "!abcd:" + fed_hostname,
+ "type": EventTypes.Member,
+ "sender": from_user,
+ "state_key": self.user,
+ "content": {"membership": Membership.INVITE},
+ "prev_events": [],
+ "auth_events": [],
+ "depth": 1,
+ "origin_server_ts": 1234,
+ }
+ )
+ self.get_success(
+ fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
+ )
+
+ def test_simple_space(self):
+ """Test a simple space with a single room."""
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+ # The result should have the space and the room in it, along with a link
+ # from space -> room.
+ expected = [(self.space, [self.room]), (self.room, ())]
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result, expected)
+
+ def test_visibility(self):
+ """A user not in a space cannot inspect it."""
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
+
+ # The user can see the space since it is publicly joinable.
+ result = self.get_success(self.handler.get_space_summary(user2, self.space))
+ expected = [(self.space, [self.room]), (self.room, ())]
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ self._assert_hierarchy(result, expected)
+
+ # If the space is made invite-only, it should no longer be viewable.
+ self.helper.send_state(
+ self.space,
+ event_type=EventTypes.JoinRules,
+ body={"join_rule": JoinRules.INVITE},
+ tok=self.token,
+ )
+ self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
+ self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
+
+ # If the space is made world-readable it should return a result.
+ self.helper.send_state(
+ self.space,
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+ tok=self.token,
+ )
+ result = self.get_success(self.handler.get_space_summary(user2, self.space))
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ self._assert_hierarchy(result, expected)
+
+ # Make it not world-readable again and confirm it results in an error.
+ self.helper.send_state(
+ self.space,
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.JOINED},
+ tok=self.token,
+ )
+ self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
+ self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
+
+ # Join the space and results should be returned.
+ self.helper.invite(self.space, targ=user2, tok=self.token)
+ self.helper.join(self.space, user2, tok=token2)
+ result = self.get_success(self.handler.get_space_summary(user2, self.space))
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ self._assert_hierarchy(result, expected)
+
+ # Attempting to view an unknown room returns the same error.
+ self.get_failure(
+ self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname),
+ AuthError,
+ )
+ self.get_failure(
+ self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname),
+ AuthError,
+ )
+
+ def _create_room_with_join_rule(
+ self, join_rule: str, room_version: Optional[str] = None, **extra_content
+ ) -> str:
+ """Create a room with the given join rule and add it to the space."""
+ room_id = self.helper.create_room_as(
+ self.user,
+ room_version=room_version,
+ tok=self.token,
+ extra_content={
+ "initial_state": [
+ {
+ "type": EventTypes.JoinRules,
+ "state_key": "",
+ "content": {
+ "join_rule": join_rule,
+ **extra_content,
+ },
+ }
+ ]
+ },
+ )
+ self._add_child(self.space, room_id, self.token)
+ return room_id
+
+ def test_filtering(self):
+ """
+ Rooms should be properly filtered to only include rooms the user has access to.
+ """
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
+
+ # Create a few rooms which will have different properties.
+ public_room = self._create_room_with_join_rule(JoinRules.PUBLIC)
+ knock_room = self._create_room_with_join_rule(
+ JoinRules.KNOCK, room_version=RoomVersions.V7.identifier
+ )
+ not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ self.helper.invite(invited_room, targ=user2, tok=self.token)
+ restricted_room = self._create_room_with_join_rule(
+ JoinRules.RESTRICTED,
+ room_version=RoomVersions.V8.identifier,
+ allow=[],
+ )
+ restricted_accessible_room = self._create_room_with_join_rule(
+ JoinRules.RESTRICTED,
+ room_version=RoomVersions.V8.identifier,
+ allow=[
+ {
+ "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP,
+ "room_id": self.space,
+ "via": [self.hs.hostname],
+ }
+ ],
+ )
+ world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ self.helper.send_state(
+ world_readable_room,
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+ tok=self.token,
+ )
+ joined_room = self._create_room_with_join_rule(JoinRules.INVITE)
+ self.helper.invite(joined_room, targ=user2, tok=self.token)
+ self.helper.join(joined_room, user2, tok=token2)
+
+ # Join the space.
+ self.helper.join(self.space, user2, tok=token2)
+ result = self.get_success(self.handler.get_space_summary(user2, self.space))
+ expected = [
+ (
+ self.space,
+ [
+ self.room,
+ public_room,
+ knock_room,
+ not_invited_room,
+ invited_room,
+ restricted_room,
+ restricted_accessible_room,
+ world_readable_room,
+ joined_room,
+ ],
+ ),
+ (self.room, ()),
+ (public_room, ()),
+ (knock_room, ()),
+ (invited_room, ()),
+ (restricted_accessible_room, ()),
+ (world_readable_room, ()),
+ (joined_room, ()),
+ ]
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+ self._assert_hierarchy(result, expected)
+
+ def test_complex_space(self):
+ """
+ Create a "complex" space to see how it handles things like loops and subspaces.
+ """
+ # Create an inaccessible room.
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
+ room2 = self.helper.create_room_as(user2, is_public=False, tok=token2)
+ # This is a bit odd as "user" is adding a room they don't know about, but
+ # it works for the tests.
+ self._add_child(self.space, room2, self.token)
+
+ # Create a subspace under the space with an additional room in it.
+ subspace = self.helper.create_room_as(
+ self.user,
+ tok=self.token,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ subroom = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, subspace, token=self.token)
+ self._add_child(subspace, subroom, token=self.token)
+ # Also add the two rooms from the space into this subspace (causing loops).
+ self._add_child(subspace, self.room, token=self.token)
+ self._add_child(subspace, room2, self.token)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+
+ # The result should include each room a single time and each link.
+ expected = [
+ (self.space, [self.room, room2, subspace]),
+ (self.room, ()),
+ (subspace, [subroom, self.room, room2]),
+ (subroom, ()),
+ ]
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result, expected)
+
+ def test_pagination(self):
+ """Test simple pagination works."""
+ room_ids = []
+ for i in range(1, 10):
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, room, self.token, order=str(i))
+ room_ids.append(room)
+ # The room created initially doesn't have an order, so comes last.
+ room_ids.append(self.room)
+
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+ )
+ # The result should have the space and all of the links, plus some of the
+ # rooms and a pagination token.
+ expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)]
+ expected += [(room_id, ()) for room_id in room_ids[:6]]
+ self._assert_hierarchy(result, expected)
+ self.assertIn("next_batch", result)
+
+ # Check the next page.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(
+ self.user, self.space, limit=5, from_token=result["next_batch"]
+ )
+ )
+ # The result should have the space and the room in it, along with a link
+ # from space -> room.
+ expected = [(room_id, ()) for room_id in room_ids[6:]]
+ self._assert_hierarchy(result, expected)
+ self.assertNotIn("next_batch", result)
+
+ def test_invalid_pagination_token(self):
+ """An invalid pagination token, or changing other parameters, shoudl be rejected."""
+ room_ids = []
+ for i in range(1, 10):
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, room, self.token, order=str(i))
+ room_ids.append(room)
+ # The room created initially doesn't have an order, so comes last.
+ room_ids.append(self.room)
+
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+ )
+ self.assertIn("next_batch", result)
+
+ # Changing the room ID, suggested-only, or max-depth causes an error.
+ self.get_failure(
+ self.handler.get_room_hierarchy(
+ self.user, self.room, from_token=result["next_batch"]
+ ),
+ SynapseError,
+ )
+ self.get_failure(
+ self.handler.get_room_hierarchy(
+ self.user,
+ self.space,
+ suggested_only=True,
+ from_token=result["next_batch"],
+ ),
+ SynapseError,
+ )
+ self.get_failure(
+ self.handler.get_room_hierarchy(
+ self.user, self.space, max_depth=0, from_token=result["next_batch"]
+ ),
+ SynapseError,
+ )
+
+ # An invalid token is ignored.
+ self.get_failure(
+ self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"),
+ SynapseError,
+ )
+
+ def test_max_depth(self):
+ """Create a deep tree to test the max depth against."""
+ spaces = [self.space]
+ rooms = [self.room]
+ for _ in range(5):
+ spaces.append(
+ self.helper.create_room_as(
+ self.user,
+ tok=self.token,
+ extra_content={
+ "creation_content": {
+ EventContentFields.ROOM_TYPE: RoomTypes.SPACE
+ }
+ },
+ )
+ )
+ self._add_child(spaces[-2], spaces[-1], self.token)
+ rooms.append(self.helper.create_room_as(self.user, tok=self.token))
+ self._add_child(spaces[-1], rooms[-1], self.token)
+
+ # Test just the space itself.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space, max_depth=0)
+ )
+ expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])]
+ self._assert_hierarchy(result, expected)
+
+ # A single additional layer.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space, max_depth=1)
+ )
+ expected += [
+ (rooms[0], ()),
+ (spaces[1], [rooms[1], spaces[2]]),
+ ]
+ self._assert_hierarchy(result, expected)
+
+ # A few layers.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space, max_depth=3)
+ )
+ expected += [
+ (rooms[1], ()),
+ (spaces[2], [rooms[2], spaces[3]]),
+ (rooms[2], ()),
+ (spaces[3], [rooms[3], spaces[4]]),
+ ]
+ self._assert_hierarchy(result, expected)
+
+ def test_unknown_room_version(self):
+ """
+ If an room with an unknown room version is encountered it should not cause
+ the entire summary to skip.
+ """
+ # Poke the database and update the room version to an unknown one.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_update(
+ "rooms",
+ keyvalues={"room_id": self.room},
+ updatevalues={"room_version": "unknown-room-version"},
+ desc="updated-room-version",
+ )
+ )
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+ # The result should have only the space, along with a link from space -> room.
+ expected = [(self.space, [self.room])]
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result, expected)
+
+ def test_fed_complex(self):
+ """
+ Return data over federation and ensure that it is handled properly.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ subspace = "#subspace:" + fed_hostname
+ subroom = "#subroom:" + fed_hostname
+
+ # Generate some good data, and some bad data:
+ #
+ # * Event *back* to the root room.
+ # * Unrelated events / rooms
+ # * Multiple levels of events (in a not-useful order, e.g. grandchild
+ # events before child events).
+
+ # Note that these entries are brief, but should contain enough info.
+ requested_room_entry = _RoomEntry(
+ subspace,
+ {
+ "room_id": subspace,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ },
+ [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": subspace,
+ "state_key": subroom,
+ "content": {"via": [fed_hostname]},
+ }
+ ],
+ )
+ child_room = {
+ "room_id": subroom,
+ "world_readable": True,
+ }
+
+ async def summarize_remote_room(
+ _self, room, suggested_only, max_children, exclude_rooms
+ ):
+ return [
+ requested_room_entry,
+ _RoomEntry(
+ subroom,
+ {
+ "room_id": subroom,
+ "world_readable": True,
+ },
+ ),
+ ]
+
+ async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ return requested_room_entry, {subroom: child_room}, set()
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, subspace, self.token)
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room",
+ new=summarize_remote_room,
+ ):
+ result = self.get_success(
+ self.handler.get_space_summary(self.user, self.space)
+ )
+
+ expected = [
+ (self.space, [self.room, subspace]),
+ (self.room, ()),
+ (subspace, [subroom]),
+ (subroom, ()),
+ ]
+ self._assert_rooms(result, expected)
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
+ new=summarize_remote_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result, expected)
+
+ def test_fed_filtering(self):
+ """
+ Rooms returned over federation should be properly filtered to only include
+ rooms the user has access to.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ subspace = "#subspace:" + fed_hostname
+
+ # Create a few rooms which will have different properties.
+ public_room = "#public:" + fed_hostname
+ knock_room = "#knock:" + fed_hostname
+ not_invited_room = "#not_invited:" + fed_hostname
+ invited_room = "#invited:" + fed_hostname
+ restricted_room = "#restricted:" + fed_hostname
+ restricted_accessible_room = "#restricted_accessible:" + fed_hostname
+ world_readable_room = "#world_readable:" + fed_hostname
+ joined_room = self.helper.create_room_as(self.user, tok=self.token)
+
+ # Poke an invite over federation into the database.
+ self._poke_fed_invite(invited_room, "@remote:" + fed_hostname)
+
+ # Note that these entries are brief, but should contain enough info.
+ children_rooms = (
+ (
+ public_room,
+ {
+ "room_id": public_room,
+ "world_readable": False,
+ "join_rules": JoinRules.PUBLIC,
+ },
+ ),
+ (
+ knock_room,
+ {
+ "room_id": knock_room,
+ "world_readable": False,
+ "join_rules": JoinRules.KNOCK,
+ },
+ ),
+ (
+ not_invited_room,
+ {
+ "room_id": not_invited_room,
+ "world_readable": False,
+ "join_rules": JoinRules.INVITE,
+ },
+ ),
+ (
+ invited_room,
+ {
+ "room_id": invited_room,
+ "world_readable": False,
+ "join_rules": JoinRules.INVITE,
+ },
+ ),
+ (
+ restricted_room,
+ {
+ "room_id": restricted_room,
+ "world_readable": False,
+ "join_rules": JoinRules.RESTRICTED,
+ "allowed_spaces": [],
+ },
+ ),
+ (
+ restricted_accessible_room,
+ {
+ "room_id": restricted_accessible_room,
+ "world_readable": False,
+ "join_rules": JoinRules.RESTRICTED,
+ "allowed_spaces": [self.room],
+ },
+ ),
+ (
+ world_readable_room,
+ {
+ "room_id": world_readable_room,
+ "world_readable": True,
+ "join_rules": JoinRules.INVITE,
+ },
+ ),
+ (
+ joined_room,
+ {
+ "room_id": joined_room,
+ "world_readable": False,
+ "join_rules": JoinRules.INVITE,
+ },
+ ),
+ )
+
+ subspace_room_entry = _RoomEntry(
+ subspace,
+ {
+ "room_id": subspace,
+ "world_readable": True,
+ },
+ # Place each room in the sub-space.
+ [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": subspace,
+ "state_key": room_id,
+ "content": {"via": [fed_hostname]},
+ }
+ for room_id, _ in children_rooms
+ ],
+ )
+
+ async def summarize_remote_room(
+ _self, room, suggested_only, max_children, exclude_rooms
+ ):
+ return [subspace_room_entry] + [
+ # A copy is made of the room data since the allowed_spaces key
+ # is removed.
+ _RoomEntry(child_room[0], dict(child_room[1]))
+ for child_room in children_rooms
+ ]
+
+ async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ return subspace_room_entry, dict(children_rooms), set()
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, subspace, self.token)
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room",
+ new=summarize_remote_room,
+ ):
+ result = self.get_success(
+ self.handler.get_space_summary(self.user, self.space)
+ )
+
+ expected = [
+ (self.space, [self.room, subspace]),
+ (self.room, ()),
+ (
+ subspace,
+ [
+ public_room,
+ knock_room,
+ not_invited_room,
+ invited_room,
+ restricted_room,
+ restricted_accessible_room,
+ world_readable_room,
+ joined_room,
+ ],
+ ),
+ (public_room, ()),
+ (knock_room, ()),
+ (invited_room, ()),
+ (restricted_accessible_room, ()),
+ (world_readable_room, ()),
+ (joined_room, ()),
+ ]
+ self._assert_rooms(result, expected)
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
+ new=summarize_remote_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result, expected)
+
+ def test_fed_invited(self):
+ """
+ A room which the user was invited to should be included in the response.
+
+ This differs from test_fed_filtering in that the room itself is being
+ queried over federation, instead of it being included as a sub-room of
+ a space in the response.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ fed_room = "#subroom:" + fed_hostname
+
+ # Poke an invite over federation into the database.
+ self._poke_fed_invite(fed_room, "@remote:" + fed_hostname)
+
+ fed_room_entry = _RoomEntry(
+ fed_room,
+ {
+ "room_id": fed_room,
+ "world_readable": False,
+ "join_rules": JoinRules.INVITE,
+ },
+ )
+
+ async def summarize_remote_room(
+ _self, room, suggested_only, max_children, exclude_rooms
+ ):
+ return [fed_room_entry]
+
+ async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ return fed_room_entry, {}, set()
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, fed_room, self.token)
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room",
+ new=summarize_remote_room,
+ ):
+ result = self.get_success(
+ self.handler.get_space_summary(self.user, self.space)
+ )
+
+ expected = [
+ (self.space, [self.room, fed_room]),
+ (self.room, ()),
+ (fed_room, ()),
+ ]
+ self._assert_rooms(result, expected)
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
+ new=summarize_remote_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result, expected)
+
+
+class RoomSummaryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
+ self.hs = hs
+ self.handler = self.hs.get_room_summary_handler()
+
+ # Create a user.
+ self.user = self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ # Create a simple room.
+ self.room = self.helper.create_room_as(self.user, tok=self.token)
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.JoinRules,
+ body={"join_rule": JoinRules.INVITE},
+ tok=self.token,
+ )
+
+ def test_own_room(self):
+ """Test a simple room created by the requester."""
+ result = self.get_success(self.handler.get_room_summary(self.user, self.room))
+ self.assertEqual(result.get("room_id"), self.room)
+
+ def test_visibility(self):
+ """A user not in a private room cannot get its summary."""
+ user2 = self.register_user("user2", "pass")
+ token2 = self.login("user2", "pass")
+
+ # The user cannot see the room.
+ self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError)
+
+ # If the room is made world-readable it should return a result.
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+ tok=self.token,
+ )
+ result = self.get_success(self.handler.get_room_summary(user2, self.room))
+ self.assertEqual(result.get("room_id"), self.room)
+
+ # Make it not world-readable again and confirm it results in an error.
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": HistoryVisibility.JOINED},
+ tok=self.token,
+ )
+ self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError)
+
+ # If the room is made public it should return a result.
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.JoinRules,
+ body={"join_rule": JoinRules.PUBLIC},
+ tok=self.token,
+ )
+ result = self.get_success(self.handler.get_room_summary(user2, self.room))
+ self.assertEqual(result.get("room_id"), self.room)
+
+ # Join the space, make it invite-only again and results should be returned.
+ self.helper.join(self.room, user2, tok=token2)
+ self.helper.send_state(
+ self.room,
+ event_type=EventTypes.JoinRules,
+ body={"join_rule": JoinRules.INVITE},
+ tok=self.token,
+ )
+ result = self.get_success(self.handler.get_room_summary(user2, self.room))
+ self.assertEqual(result.get("room_id"), self.room)
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
new file mode 100644
index 0000000000..6f77b1237c
--- /dev/null
+++ b/tests/handlers/test_send_email.py
@@ -0,0 +1,112 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Tuple
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.address import IPv4Address
+from twisted.internet.defer import ensureDeferred
+from twisted.mail import interfaces, smtp
+
+from tests.server import FakeTransport
+from tests.unittest import HomeserverTestCase
+
+
+@implementer(interfaces.IMessageDelivery)
+class _DummyMessageDelivery:
+ def __init__(self):
+ # (recipient, message) tuples
+ self.messages: List[Tuple[smtp.Address, bytes]] = []
+
+ def receivedHeader(self, helo, origin, recipients):
+ return None
+
+ def validateFrom(self, helo, origin):
+ return origin
+
+ def record_message(self, recipient: smtp.Address, message: bytes):
+ self.messages.append((recipient, message))
+
+ def validateTo(self, user: smtp.User):
+ return lambda: _DummyMessage(self, user)
+
+
+@implementer(interfaces.IMessageSMTP)
+class _DummyMessage:
+ """IMessageSMTP implementation which saves the message delivered to it
+ to the _DummyMessageDelivery object.
+ """
+
+ def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
+ self._delivery = delivery
+ self._user = user
+ self._buffer: List[bytes] = []
+
+ def lineReceived(self, line):
+ self._buffer.append(line)
+
+ def eomReceived(self):
+ message = b"\n".join(self._buffer) + b"\n"
+ self._delivery.record_message(self._user.dest, message)
+ return defer.succeed(b"saved")
+
+ def connectionLost(self):
+ pass
+
+
+class SendEmailHandlerTestCase(HomeserverTestCase):
+ def test_send_email(self):
+ """Happy-path test that we can send email to a non-TLS server."""
+ h = self.hs.get_send_email_handler()
+ d = ensureDeferred(
+ h.send_email(
+ "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
+ )
+ )
+ # there should be an attempt to connect to localhost:25
+ self.assertEqual(len(self.reactor.tcpClients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
+ 0
+ ]
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 25)
+
+ # wire it up to an SMTP server
+ message_delivery = _DummyMessageDelivery()
+ server_protocol = smtp.ESMTP()
+ server_protocol.delivery = message_delivery
+ # make sure that the server uses the test reactor to set timeouts
+ server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
+
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
+ server_protocol.makeConnection(
+ FakeTransport(
+ client_protocol,
+ self.reactor,
+ peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+ )
+ )
+
+ # the message should now get delivered
+ self.get_success(d, by=0.1)
+
+ # check it arrived
+ self.assertEqual(len(message_delivery.messages), 1)
+ user, msg = message_delivery.messages.pop()
+ self.assertEqual(str(user), "foo@bar.com")
+ self.assertIn(b"Subject: test subject", msg)
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
deleted file mode 100644
index 3f73ad7f94..0000000000
--- a/tests/handlers/test_space_summary.py
+++ /dev/null
@@ -1,543 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, Iterable, Optional, Tuple
-from unittest import mock
-
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- HistoryVisibility,
- JoinRules,
- Membership,
- RestrictedJoinRuleTypes,
- RoomTypes,
-)
-from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
-from synapse.handlers.space_summary import _child_events_comparison_key
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.server import HomeServer
-from synapse.types import JsonDict
-
-from tests import unittest
-
-
-def _create_event(room_id: str, order: Optional[Any] = None):
- result = mock.Mock()
- result.room_id = room_id
- result.content = {}
- if order is not None:
- result.content["order"] = order
- return result
-
-
-def _order(*events):
- return sorted(events, key=_child_events_comparison_key)
-
-
-class TestSpaceSummarySort(unittest.TestCase):
- def test_no_order_last(self):
- """An event with no ordering is placed behind those with an ordering."""
- ev1 = _create_event("!abc:test")
- ev2 = _create_event("!xyz:test", "xyz")
-
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
- def test_order(self):
- """The ordering should be used."""
- ev1 = _create_event("!abc:test", "xyz")
- ev2 = _create_event("!xyz:test", "abc")
-
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
- def test_order_room_id(self):
- """Room ID is a tie-breaker for ordering."""
- ev1 = _create_event("!abc:test", "abc")
- ev2 = _create_event("!xyz:test", "abc")
-
- self.assertEqual([ev1, ev2], _order(ev1, ev2))
-
- def test_invalid_ordering_type(self):
- """Invalid orderings are considered the same as missing."""
- ev1 = _create_event("!abc:test", 1)
- ev2 = _create_event("!xyz:test", "xyz")
-
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
- ev1 = _create_event("!abc:test", {})
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
- ev1 = _create_event("!abc:test", [])
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
- ev1 = _create_event("!abc:test", True)
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
- def test_invalid_ordering_value(self):
- """Invalid orderings are considered the same as missing."""
- ev1 = _create_event("!abc:test", "foo\n")
- ev2 = _create_event("!xyz:test", "xyz")
-
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
- ev1 = _create_event("!abc:test", "a" * 51)
- self.assertEqual([ev2, ev1], _order(ev1, ev2))
-
-
-class SpaceSummaryTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs: HomeServer):
- self.hs = hs
- self.handler = self.hs.get_space_summary_handler()
-
- # Create a user.
- self.user = self.register_user("user", "pass")
- self.token = self.login("user", "pass")
-
- # Create a space and a child room.
- self.space = self.helper.create_room_as(
- self.user,
- tok=self.token,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- self.room = self.helper.create_room_as(self.user, tok=self.token)
- self._add_child(self.space, self.room, self.token)
-
- def _add_child(self, space_id: str, room_id: str, token: str) -> None:
- """Add a child room to a space."""
- self.helper.send_state(
- space_id,
- event_type=EventTypes.SpaceChild,
- body={"via": [self.hs.hostname]},
- tok=token,
- state_key=room_id,
- )
-
- def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None:
- """Assert that the expected room IDs are in the response."""
- self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms)
-
- def _assert_events(
- self, result: JsonDict, events: Iterable[Tuple[str, str]]
- ) -> None:
- """Assert that the expected parent / child room IDs are in the response."""
- self.assertCountEqual(
- [
- (event.get("room_id"), event.get("state_key"))
- for event in result["events"]
- ],
- events,
- )
-
- def test_simple_space(self):
- """Test a simple space with a single room."""
- result = self.get_success(self.handler.get_space_summary(self.user, self.space))
- # The result should have the space and the room in it, along with a link
- # from space -> room.
- self._assert_rooms(result, [self.space, self.room])
- self._assert_events(result, [(self.space, self.room)])
-
- def test_visibility(self):
- """A user not in a space cannot inspect it."""
- user2 = self.register_user("user2", "pass")
- token2 = self.login("user2", "pass")
-
- # The user cannot see the space.
- self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
-
- # If the space is made world-readable it should return a result.
- self.helper.send_state(
- self.space,
- event_type=EventTypes.RoomHistoryVisibility,
- body={"history_visibility": HistoryVisibility.WORLD_READABLE},
- tok=self.token,
- )
- result = self.get_success(self.handler.get_space_summary(user2, self.space))
- self._assert_rooms(result, [self.space, self.room])
- self._assert_events(result, [(self.space, self.room)])
-
- # Make it not world-readable again and confirm it results in an error.
- self.helper.send_state(
- self.space,
- event_type=EventTypes.RoomHistoryVisibility,
- body={"history_visibility": HistoryVisibility.JOINED},
- tok=self.token,
- )
- self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
-
- # Join the space and results should be returned.
- self.helper.join(self.space, user2, tok=token2)
- result = self.get_success(self.handler.get_space_summary(user2, self.space))
- self._assert_rooms(result, [self.space, self.room])
- self._assert_events(result, [(self.space, self.room)])
-
- def _create_room_with_join_rule(
- self, join_rule: str, room_version: Optional[str] = None, **extra_content
- ) -> str:
- """Create a room with the given join rule and add it to the space."""
- room_id = self.helper.create_room_as(
- self.user,
- room_version=room_version,
- tok=self.token,
- extra_content={
- "initial_state": [
- {
- "type": EventTypes.JoinRules,
- "state_key": "",
- "content": {
- "join_rule": join_rule,
- **extra_content,
- },
- }
- ]
- },
- )
- self._add_child(self.space, room_id, self.token)
- return room_id
-
- def test_filtering(self):
- """
- Rooms should be properly filtered to only include rooms the user has access to.
- """
- user2 = self.register_user("user2", "pass")
- token2 = self.login("user2", "pass")
-
- # Create a few rooms which will have different properties.
- public_room = self._create_room_with_join_rule(JoinRules.PUBLIC)
- knock_room = self._create_room_with_join_rule(
- JoinRules.KNOCK, room_version=RoomVersions.V7.identifier
- )
- not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
- invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
- self.helper.invite(invited_room, targ=user2, tok=self.token)
- restricted_room = self._create_room_with_join_rule(
- JoinRules.MSC3083_RESTRICTED,
- room_version=RoomVersions.MSC3083.identifier,
- allow=[],
- )
- restricted_accessible_room = self._create_room_with_join_rule(
- JoinRules.MSC3083_RESTRICTED,
- room_version=RoomVersions.MSC3083.identifier,
- allow=[
- {
- "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP,
- "room_id": self.space,
- "via": [self.hs.hostname],
- }
- ],
- )
- world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE)
- self.helper.send_state(
- world_readable_room,
- event_type=EventTypes.RoomHistoryVisibility,
- body={"history_visibility": HistoryVisibility.WORLD_READABLE},
- tok=self.token,
- )
- joined_room = self._create_room_with_join_rule(JoinRules.INVITE)
- self.helper.invite(joined_room, targ=user2, tok=self.token)
- self.helper.join(joined_room, user2, tok=token2)
-
- # Join the space.
- self.helper.join(self.space, user2, tok=token2)
- result = self.get_success(self.handler.get_space_summary(user2, self.space))
-
- self._assert_rooms(
- result,
- [
- self.space,
- self.room,
- public_room,
- knock_room,
- invited_room,
- restricted_accessible_room,
- world_readable_room,
- joined_room,
- ],
- )
- self._assert_events(
- result,
- [
- (self.space, self.room),
- (self.space, public_room),
- (self.space, knock_room),
- (self.space, not_invited_room),
- (self.space, invited_room),
- (self.space, restricted_room),
- (self.space, restricted_accessible_room),
- (self.space, world_readable_room),
- (self.space, joined_room),
- ],
- )
-
- def test_complex_space(self):
- """
- Create a "complex" space to see how it handles things like loops and subspaces.
- """
- # Create an inaccessible room.
- user2 = self.register_user("user2", "pass")
- token2 = self.login("user2", "pass")
- room2 = self.helper.create_room_as(user2, is_public=False, tok=token2)
- # This is a bit odd as "user" is adding a room they don't know about, but
- # it works for the tests.
- self._add_child(self.space, room2, self.token)
-
- # Create a subspace under the space with an additional room in it.
- subspace = self.helper.create_room_as(
- self.user,
- tok=self.token,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- subroom = self.helper.create_room_as(self.user, tok=self.token)
- self._add_child(self.space, subspace, token=self.token)
- self._add_child(subspace, subroom, token=self.token)
- # Also add the two rooms from the space into this subspace (causing loops).
- self._add_child(subspace, self.room, token=self.token)
- self._add_child(subspace, room2, self.token)
-
- result = self.get_success(self.handler.get_space_summary(self.user, self.space))
-
- # The result should include each room a single time and each link.
- self._assert_rooms(result, [self.space, self.room, subspace, subroom])
- self._assert_events(
- result,
- [
- (self.space, self.room),
- (self.space, room2),
- (self.space, subspace),
- (subspace, subroom),
- (subspace, self.room),
- (subspace, room2),
- ],
- )
-
- def test_fed_complex(self):
- """
- Return data over federation and ensure that it is handled properly.
- """
- fed_hostname = self.hs.hostname + "2"
- subspace = "#subspace:" + fed_hostname
- subroom = "#subroom:" + fed_hostname
-
- async def summarize_remote_room(
- _self, room, suggested_only, max_children, exclude_rooms
- ):
- # Return some good data, and some bad data:
- #
- # * Event *back* to the root room.
- # * Unrelated events / rooms
- # * Multiple levels of events (in a not-useful order, e.g. grandchild
- # events before child events).
-
- # Note that these entries are brief, but should contain enough info.
- rooms = [
- {
- "room_id": subspace,
- "world_readable": True,
- "room_type": RoomTypes.SPACE,
- },
- {
- "room_id": subroom,
- "world_readable": True,
- },
- ]
- event_content = {"via": [fed_hostname]}
- events = [
- {
- "room_id": subspace,
- "state_key": subroom,
- "content": event_content,
- },
- ]
- return rooms, events
-
- # Add a room to the space which is on another server.
- self._add_child(self.space, subspace, self.token)
-
- with mock.patch(
- "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room",
- new=summarize_remote_room,
- ):
- result = self.get_success(
- self.handler.get_space_summary(self.user, self.space)
- )
-
- self._assert_rooms(result, [self.space, self.room, subspace, subroom])
- self._assert_events(
- result,
- [
- (self.space, self.room),
- (self.space, subspace),
- (subspace, subroom),
- ],
- )
-
- def test_fed_filtering(self):
- """
- Rooms returned over federation should be properly filtered to only include
- rooms the user has access to.
- """
- fed_hostname = self.hs.hostname + "2"
- subspace = "#subspace:" + fed_hostname
-
- # Create a few rooms which will have different properties.
- public_room = "#public:" + fed_hostname
- knock_room = "#knock:" + fed_hostname
- not_invited_room = "#not_invited:" + fed_hostname
- invited_room = "#invited:" + fed_hostname
- restricted_room = "#restricted:" + fed_hostname
- restricted_accessible_room = "#restricted_accessible:" + fed_hostname
- world_readable_room = "#world_readable:" + fed_hostname
- joined_room = self.helper.create_room_as(self.user, tok=self.token)
-
- # Poke an invite over federation into the database.
- fed_handler = self.hs.get_federation_handler()
- event = make_event_from_dict(
- {
- "room_id": invited_room,
- "event_id": "!abcd:" + fed_hostname,
- "type": EventTypes.Member,
- "sender": "@remote:" + fed_hostname,
- "state_key": self.user,
- "content": {"membership": Membership.INVITE},
- "prev_events": [],
- "auth_events": [],
- "depth": 1,
- "origin_server_ts": 1234,
- }
- )
- self.get_success(
- fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
- )
-
- async def summarize_remote_room(
- _self, room, suggested_only, max_children, exclude_rooms
- ):
- # Note that these entries are brief, but should contain enough info.
- rooms = [
- {
- "room_id": public_room,
- "world_readable": False,
- "join_rules": JoinRules.PUBLIC,
- },
- {
- "room_id": knock_room,
- "world_readable": False,
- "join_rules": JoinRules.KNOCK,
- },
- {
- "room_id": not_invited_room,
- "world_readable": False,
- "join_rules": JoinRules.INVITE,
- },
- {
- "room_id": invited_room,
- "world_readable": False,
- "join_rules": JoinRules.INVITE,
- },
- {
- "room_id": restricted_room,
- "world_readable": False,
- "join_rules": JoinRules.MSC3083_RESTRICTED,
- "allowed_spaces": [],
- },
- {
- "room_id": restricted_accessible_room,
- "world_readable": False,
- "join_rules": JoinRules.MSC3083_RESTRICTED,
- "allowed_spaces": [self.room],
- },
- {
- "room_id": world_readable_room,
- "world_readable": True,
- "join_rules": JoinRules.INVITE,
- },
- {
- "room_id": joined_room,
- "world_readable": False,
- "join_rules": JoinRules.INVITE,
- },
- ]
-
- # Place each room in the sub-space.
- event_content = {"via": [fed_hostname]}
- events = [
- {
- "room_id": subspace,
- "state_key": room["room_id"],
- "content": event_content,
- }
- for room in rooms
- ]
-
- # Also include the subspace.
- rooms.insert(
- 0,
- {
- "room_id": subspace,
- "world_readable": True,
- },
- )
- return rooms, events
-
- # Add a room to the space which is on another server.
- self._add_child(self.space, subspace, self.token)
-
- with mock.patch(
- "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room",
- new=summarize_remote_room,
- ):
- result = self.get_success(
- self.handler.get_space_summary(self.user, self.space)
- )
-
- self._assert_rooms(
- result,
- [
- self.space,
- self.room,
- subspace,
- public_room,
- knock_room,
- invited_room,
- restricted_accessible_room,
- world_readable_room,
- joined_room,
- ],
- )
- self._assert_events(
- result,
- [
- (self.space, self.room),
- (self.space, subspace),
- (subspace, public_room),
- (subspace, knock_room),
- (subspace, not_invited_room),
- (subspace, invited_room),
- (subspace, restricted_room),
- (subspace, restricted_accessible_room),
- (subspace, world_readable_room),
- (subspace, joined_room),
- ],
- )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index e4059acda3..1ba4c05b9b 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -13,7 +13,7 @@
# limitations under the License.
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.storage.databases.main import stats
from tests import unittest
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 84f05f6c58..339c039914 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.room_versions import RoomVersions
from synapse.handlers.sync import SyncConfig
+from synapse.rest import admin
+from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
from synapse.types import UserID, create_requester
import tests.unittest
@@ -24,8 +31,14 @@ import tests.utils
class SyncTestCase(tests.unittest.HomeserverTestCase):
"""Tests Sync Handler."""
- def prepare(self, reactor, clock, hs):
- self.hs = hs
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
@@ -68,12 +81,124 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_unknown_room_version(self):
+ """
+ A room with an unknown room version should not break sync (and should be excluded).
+ """
+ inviter = self.register_user("creator", "pass", admin=True)
+ inviter_tok = self.login("@creator:test", "pass")
+
+ user = self.register_user("user", "pass")
+ tok = self.login("user", "pass")
+
+ # Do an initial sync on a different device.
+ requester = create_requester(user)
+ initial_result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user, device_id="dev")
+ )
+ )
+
+ # Create a room as the user.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ # Invite the user to the room as someone else.
+ invite_room = self.helper.create_room_as(inviter, tok=inviter_tok)
+ self.helper.invite(invite_room, targ=user, tok=inviter_tok)
+
+ knock_room = self.helper.create_room_as(
+ inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok
+ )
+ self.helper.send_state(
+ knock_room,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=inviter_tok,
+ )
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/knock/%s" % (knock_room,),
+ b"{}",
+ tok,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # The rooms should appear in the sync response.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user)
+ )
+ )
+ self.assertIn(joined_room, [r.room_id for r in result.joined])
+ self.assertIn(invite_room, [r.room_id for r in result.invited])
+ self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+ # Test a incremental sync (by providing a since_token).
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config=generate_sync_config(user, device_id="dev"),
+ since_token=initial_result.next_batch,
+ )
+ )
+ self.assertIn(joined_room, [r.room_id for r in result.joined])
+ self.assertIn(invite_room, [r.room_id for r in result.invited])
+ self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+ # Poke the database and update the room version to an unknown one.
+ for room_id in (joined_room, invite_room, knock_room):
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_update(
+ "rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_version": "unknown-room-version"},
+ desc="updated-room-version",
+ )
+ )
+
+ # Blow away caches (supported room versions can only change due to a restart).
+ self.get_success(
+ self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+ )
+ self.store._get_event_cache.clear()
+
+ # The rooms should be excluded from the sync response.
+ # Get a new request key.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user)
+ )
+ )
+ self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+ self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+ self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+ # The rooms should also not be in an incremental sync.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config=generate_sync_config(user, device_id="dev"),
+ since_token=initial_result.next_batch,
+ )
+ )
+ self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+ self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+ self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+
+_request_key = 0
+
-def generate_sync_config(user_id: str) -> SyncConfig:
+def generate_sync_config(
+ user_id: str, device_id: Optional[str] = "device_id"
+) -> SyncConfig:
+ """Generate a sync config (with a unique request key)."""
+ global _request_key
+ _request_key += 1
return SyncConfig(
- user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+ user=UserID.from_string(user_id),
filter_collection=DEFAULT_FILTER_COLLECTION,
is_guest=False,
- request_key="request_key",
- device_id="device_id",
+ request_key=("request_key", _request_key),
+ device_id=device_id,
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 549876dc85..a91d31ce61 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -16,10 +16,9 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
+from synapse.api.constants import UserTypes
from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client import login, room, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
@@ -188,100 +187,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
- @override_config({"encryption_enabled_by_default_for_room_type": "all"})
- def test_encrypted_by_default_config_option_all(self):
- """Tests that invite-only and non-invite-only rooms have encryption enabled by
- default when the config option encryption_enabled_by_default_for_room_type is "all".
- """
- # Create a user
- user = self.register_user("user", "pass")
- user_token = self.login(user, "pass")
-
- # Create an invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
-
- # Check that the room has an encryption state event
- event_content = self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- )
- self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
-
- # Create a non invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
-
- # Check that the room has an encryption state event
- event_content = self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- )
- self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
-
- @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
- def test_encrypted_by_default_config_option_invite(self):
- """Tests that only new, invite-only rooms have encryption enabled by default when
- the config option encryption_enabled_by_default_for_room_type is "invite".
- """
- # Create a user
- user = self.register_user("user", "pass")
- user_token = self.login(user, "pass")
-
- # Create an invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
-
- # Check that the room has an encryption state event
- event_content = self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- )
- self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
-
- # Create a non invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
-
- # Check that the room does not have an encryption state event
- self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- expect_code=404,
- )
-
- @override_config({"encryption_enabled_by_default_for_room_type": "off"})
- def test_encrypted_by_default_config_option_off(self):
- """Tests that neither new invite-only nor non-invite-only rooms have encryption
- enabled by default when the config option
- encryption_enabled_by_default_for_room_type is "off".
- """
- # Create a user
- user = self.register_user("user", "pass")
- user_token = self.login(user, "pass")
-
- # Create an invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
-
- # Check that the room does not have an encryption state event
- self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- expect_code=404,
- )
-
- # Create a non invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
-
- # Check that the room does not have an encryption state event
- self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- expect_code=404,
- )
-
def test_spam_checker(self):
"""
A user which fails the spam checks will not appear in search results.
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index a37bce08c3..992d8f94fd 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
-from typing import Optional
-from unittest.mock import Mock
+import os
+from typing import Iterable, Optional
+from unittest.mock import Mock, patch
import treq
from netaddr import IPSet
@@ -22,11 +24,12 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
+from twisted.internet.interfaces import IProtocolFactory
from twisted.internet.protocol import Factory
-from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
-from twisted.web.http import HTTPChannel
+from twisted.web.http import HTTPChannel, Request
from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
@@ -49,24 +52,6 @@ from tests.utils import default_config
logger = logging.getLogger(__name__)
-test_server_connection_factory = None
-
-
-def get_connection_factory():
- # this needs to happen once, but not until we are ready to run the first test
- global test_server_connection_factory
- if test_server_connection_factory is None:
- test_server_connection_factory = TestServerTLSConnectionFactory(
- sanlist=[
- b"DNS:testserv",
- b"DNS:target-server",
- b"DNS:xn--bcher-kva.com",
- b"IP:1.2.3.4",
- b"IP:::1",
- ]
- )
- return test_server_connection_factory
-
# Once Async Mocks or lambdas are supported this can go away.
def generate_resolve_service(result):
@@ -100,24 +85,38 @@ class MatrixFederationAgentTests(unittest.TestCase):
had_well_known_cache=self.had_well_known_cache,
)
- self.agent = MatrixFederationAgent(
- reactor=self.reactor,
- tls_client_options_factory=self.tls_factory,
- user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
- ip_blacklist=IPSet(),
- _srv_resolver=self.mock_resolver,
- _well_known_resolver=self.well_known_resolver,
- )
-
- def _make_connection(self, client_factory, expected_sni):
+ def _make_connection(
+ self,
+ client_factory: IProtocolFactory,
+ ssl: bool = True,
+ expected_sni: bytes = None,
+ tls_sanlist: Optional[Iterable[bytes]] = None,
+ ) -> HTTPChannel:
"""Builds a test server, and completes the outgoing client connection
+ Args:
+ client_factory: the the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ ssl: If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+ False is set only for when proxy expect http connection.
+ Otherwise federation requests use always https.
+
+ expected_sni: the expected SNI value
+
+ tls_sanlist: list of SAN entries for the TLS cert presented by the server.
Returns:
- HTTPChannel: the test server
+ the server Protocol returned by server_factory
"""
# build the test server
- server_tls_protocol = _build_test_server(get_connection_factory())
+ server_factory = _get_test_protocol_factory()
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
+
+ server_protocol = server_factory.buildProtocol(None)
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -128,35 +127,39 @@ class MatrixFederationAgentTests(unittest.TestCase):
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(
- FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+ FakeTransport(server_protocol, self.reactor, client_protocol)
)
- # tell the server tls protocol to send its stuff back to the client, too
- server_tls_protocol.makeConnection(
- FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
)
- # grab a hold of the TLS connection, in case it gets torn down
- server_tls_connection = server_tls_protocol._tlsConnection
-
- # fish the test server back out of the server-side TLS protocol.
- http_protocol = server_tls_protocol.wrappedProtocol
+ if ssl:
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_protocol.wrappedProtocol
+ # grab a hold of the TLS connection, in case it gets torn down
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
- # give the reactor a pump to get the TLS juices flowing.
- self.reactor.pump((0.1,))
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
# check the SNI
- server_name = server_tls_connection.get_servername()
- self.assertEqual(
- server_name,
- expected_sni,
- "Expected SNI %s but got %s" % (expected_sni, server_name),
- )
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ f"Expected SNI {expected_sni!s} but got {server_name!s}",
+ )
return http_protocol
@defer.inlineCallbacks
- def _make_get_request(self, uri):
+ def _make_get_request(self, uri: bytes):
"""
Sends a simple GET request via the agent, and checks its logcontext management
"""
@@ -180,20 +183,20 @@ class MatrixFederationAgentTests(unittest.TestCase):
def _handle_well_known_connection(
self,
- client_factory,
- expected_sni,
- content,
+ client_factory: IProtocolFactory,
+ expected_sni: bytes,
+ content: bytes,
response_headers: Optional[dict] = None,
- ):
+ ) -> HTTPChannel:
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
Args:
- client_factory (IProtocolFactory): outgoing connection
- expected_sni (bytes): SNI that we expect the outgoing connection to send
- content (bytes): content to send back as the .well-known
+ client_factory: outgoing connection
+ expected_sni: SNI that we expect the outgoing connection to send
+ content: content to send back as the .well-known
Returns:
- HTTPChannel: server impl
+ server impl
"""
# make the connection for .well-known
well_known_server = self._make_connection(
@@ -209,7 +212,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
return well_known_server
def _send_well_known_response(
- self, request, content, headers: Optional[dict] = None
+ self,
+ request: Request,
+ content: bytes,
+ headers: Optional[dict] = None,
):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
@@ -225,10 +231,37 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((0.1,))
- def test_get(self):
+ def _make_agent(self) -> MatrixFederationAgent:
"""
- happy-path test of a GET request with an explicit port
+ If a proxy server is set, the MatrixFederationAgent must be created again
+ because it is created too early during setUp
"""
+ return MatrixFederationAgent(
+ reactor=self.reactor,
+ tls_client_options_factory=self.tls_factory,
+ user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
+ ip_whitelist=IPSet(),
+ ip_blacklist=IPSet(),
+ _srv_resolver=self.mock_resolver,
+ _well_known_resolver=self.well_known_resolver,
+ )
+
+ def test_get(self):
+ """happy-path test of a GET request with an explicit port"""
+ self._do_get()
+
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "proxy.com", "no_proxy": "testserv"},
+ )
+ def test_get_bypass_proxy(self):
+ """test of a GET request with an explicit port and bypass proxy"""
+ self._do_get()
+
+ def _do_get(self):
+ """test of a GET request with an explicit port"""
+ self.agent = self._make_agent()
+
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
@@ -282,10 +315,188 @@ class MatrixFederationAgentTests(unittest.TestCase):
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
+ @patch.dict(
+ os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
+ )
+ def test_get_via_http_proxy(self):
+ """test for federation request through a http proxy"""
+ self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
+
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
+ )
+ def test_get_via_http_proxy_with_auth(self):
+ """test for federation request through a http proxy with authentication"""
+ self._do_get_via_proxy(
+ expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
+ )
+
+ @patch.dict(
+ os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
+ )
+ def test_get_via_https_proxy(self):
+ """test for federation request through a https proxy"""
+ self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
+
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
+ )
+ def test_get_via_https_proxy_with_auth(self):
+ """test for federation request through a https proxy with authentication"""
+ self._do_get_via_proxy(
+ expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
+ )
+
+ def _do_get_via_proxy(
+ self,
+ expect_proxy_ssl: bool = False,
+ expected_auth_credentials: Optional[bytes] = None,
+ ):
+ """Send a https federation request via an agent and check that it is correctly
+ received at the proxy and client. The proxy can use either http or https.
+ Args:
+ expect_proxy_ssl: True if we expect the request to connect to the proxy via https.
+ expected_auth_credentials: credentials we expect to be presented to authenticate at the proxy
+ """
+ self.agent = self._make_agent()
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ self.reactor.lookups["proxy.com"] = "9.9.9.9"
+ test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ # make sure we are connecting to the proxy
+ self.assertEqual(host, "9.9.9.9")
+ self.assertEqual(port, 1080)
+
+ # make a test server to act as the proxy, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory,
+ ssl=expect_proxy_ssl,
+ tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
+ expected_sni=b"proxy.com" if expect_proxy_ssl else None,
+ )
+
+ assert isinstance(proxy_server, HTTPChannel)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"testserv:8448")
+
+ # Check whether auth credentials have been supplied to the proxy
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+
+ if expected_auth_credentials is not None:
+ # Compute the correct header value for Proxy-Authorization
+ encoded_credentials = base64.b64encode(expected_auth_credentials)
+ expected_header_value = b"Basic " + encoded_credentials
+
+ # Validate the header's value
+ self.assertIn(expected_header_value, proxy_auth_header_values)
+ else:
+ # Check that the Proxy-Authorization header has not been supplied to the proxy
+ self.assertIsNone(proxy_auth_header_values)
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ request.finish()
+
+ # now we make another test server to act as the upstream HTTP server.
+ server_ssl_protocol = _wrap_server_factory_for_tls(
+ _get_test_protocol_factory()
+ ).buildProtocol(None)
+
+ # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
+ proxy_server_transport = proxy_server.transport
+ server_ssl_protocol.makeConnection(proxy_server_transport)
+
+ # ... and replace the protocol on the proxy's transport with the
+ # TLSMemoryBIOProtocol for the test server, so that incoming traffic
+ # to the proxy gets sent over to the HTTP(s) server.
+
+ # See also comment at `_do_https_request_via_proxy`
+ # in ../test_proxyagent.py for more details
+ if expect_proxy_ssl:
+ assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol)
+ proxy_server_transport.wrappedProtocol = server_ssl_protocol
+ else:
+ assert isinstance(proxy_server_transport, FakeTransport)
+ client_protocol = proxy_server_transport.other
+ c2s_transport = client_protocol.transport
+ c2s_transport.other = server_ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = server_ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"testserv"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ f"Expected SNI {expected_sni!s} but got {server_name!s}",
+ )
+
+ # now there should be a pending request
+ http_server = server_ssl_protocol.wrappedProtocol
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]
+ )
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+ )
+ # Check that the destination server DID NOT receive proxy credentials
+ self.assertIsNone(request.requestHeaders.getRawHeaders(b"Proxy-Authorization"))
+ content = request.content.read()
+ self.assertEqual(content, b"")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # send the headers
+ request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"])
+ request.write("")
+
+ self.reactor.pump((0.1,))
+
+ response = self.successResultOf(test_d)
+
+ # that should give us a Response object
+ self.assertEqual(response.code, 200)
+
+ # Send the body
+ request.write('{ "a": 1 }'.encode("ascii"))
+ request.finish()
+
+ self.reactor.pump((0.1,))
+
+ # check it can be read
+ json = self.successResultOf(treq.json_content(response))
+ self.assertEqual(json, {"a": 1})
+
def test_get_ip_address(self):
"""
Test the behaviour when the server name contains an explicit IP (with no port)
"""
+ self.agent = self._make_agent()
+
# there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
@@ -320,6 +531,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name contains an explicit IPv6 address
(with no port)
"""
+ self.agent = self._make_agent()
# there will be a getaddrinfo on the IP
self.reactor.lookups["::1"] = "::1"
@@ -355,6 +567,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name contains an explicit IPv6 address
(with explicit port)
"""
+ self.agent = self._make_agent()
# there will be a getaddrinfo on the IP
self.reactor.lookups["::1"] = "::1"
@@ -389,6 +602,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
+ self.agent = self._make_agent()
+
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4"
@@ -441,6 +656,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name contains an explicit IP, but
the server cert doesn't cover it
"""
+ self.agent = self._make_agent()
+
# there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.5"] = "1.2.3.5"
@@ -471,6 +688,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
+ self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -524,6 +742,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_get_well_known(self):
"""Test the behaviour when the .well-known delegates elsewhere"""
+ self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -587,6 +806,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
+ self.agent = self._make_agent()
+
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -675,6 +896,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
+ self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -743,6 +965,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
reactor=self.reactor,
tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
+ ip_whitelist=IPSet(),
ip_blacklist=IPSet(),
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
@@ -780,6 +1003,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when there is a single SRV record
"""
+ self.agent = self._make_agent()
+
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"srvtarget", port=8443)]
)
@@ -820,6 +1045,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known redirects to a place where there
is a SRV.
"""
+ self.agent = self._make_agent()
+
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["srvtarget"] = "5.6.7.8"
@@ -876,6 +1103,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
+ self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
@@ -937,6 +1165,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""
+ self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
@@ -1140,6 +1369,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails."""
+ self.agent = self._make_agent()
+
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[
Server(host=b"target.com", port=8443),
@@ -1266,34 +1497,49 @@ def _check_logcontext(context):
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
-def _build_test_server(connection_creator):
- """Construct a test server
-
- This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
-
+def _wrap_server_factory_for_tls(
+ factory: IProtocolFactory, sanlist: Iterable[bytes] = None
+) -> IProtocolFactory:
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
Args:
- connection_creator (IOpenSSLServerConnectionCreator): thing to build
- SSL connections
- sanlist (list[bytes]): list of the SAN entries for the cert returned
- by the server
+ factory: protocol factory to wrap
+ sanlist: list of domains the cert should be valid for
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [
+ b"DNS:testserv",
+ b"DNS:target-server",
+ b"DNS:xn--bcher-kva.com",
+ b"IP:1.2.3.4",
+ b"IP:::1",
+ ]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+def _get_test_protocol_factory() -> IProtocolFactory:
+ """Get a protocol Factory which will build an HTTPChannel
Returns:
- TLSMemoryBIOProtocol
+ interfaces.IProtocolFactory
"""
server_factory = Factory.forProtocol(HTTPChannel)
+
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
- server_tls_factory = TLSMemoryBIOFactory(
- connection_creator, isClient=False, wrappedFactory=server_factory
- )
-
- return server_tls_factory.buildProtocol(None)
+ return server_factory
-def _log_request(request):
+def _log_request(request: str):
"""Implements Factory.log, which is expected by Request.finish"""
- logger.info("Completed request %s", request)
+ logger.info(f"Completed request {request}")
@implementer(IPolicyForHTTPS)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index e5865c161d..2db77c6a73 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -29,7 +29,8 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
from synapse.http.client import BlacklistingReactorWrapper
-from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy
+from synapse.http.connectproxyclient import ProxyCredentials
+from synapse.http.proxyagent import ProxyAgent, parse_proxy
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
from tests.server import FakeTransport, ThreadedMemoryReactorClock
@@ -392,7 +393,9 @@ class MatrixFederationAgentTests(TestCase):
"""
Tests that requests can be made through a proxy.
"""
- self._do_http_request_via_proxy(ssl=False, auth_credentials=None)
+ self._do_http_request_via_proxy(
+ expect_proxy_ssl=False, expected_auth_credentials=None
+ )
@patch.dict(
os.environ,
@@ -402,13 +405,17 @@ class MatrixFederationAgentTests(TestCase):
"""
Tests that authenticated requests can be made through a proxy.
"""
- self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies")
+ self._do_http_request_via_proxy(
+ expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
+ )
@patch.dict(
os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
)
def test_http_request_via_https_proxy(self):
- self._do_http_request_via_proxy(ssl=True, auth_credentials=None)
+ self._do_http_request_via_proxy(
+ expect_proxy_ssl=True, expected_auth_credentials=None
+ )
@patch.dict(
os.environ,
@@ -418,12 +425,16 @@ class MatrixFederationAgentTests(TestCase):
},
)
def test_http_request_via_https_proxy_with_auth(self):
- self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies")
+ self._do_http_request_via_proxy(
+ expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
+ )
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self):
"""Tests that TLS-encrypted requests can be made through a proxy"""
- self._do_https_request_via_proxy(ssl=False, auth_credentials=None)
+ self._do_https_request_via_proxy(
+ expect_proxy_ssl=False, expected_auth_credentials=None
+ )
@patch.dict(
os.environ,
@@ -431,14 +442,18 @@ class MatrixFederationAgentTests(TestCase):
)
def test_https_request_via_proxy_with_auth(self):
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
- self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies")
+ self._do_https_request_via_proxy(
+ expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
+ )
@patch.dict(
os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
)
def test_https_request_via_https_proxy(self):
"""Tests that TLS-encrypted requests can be made through a proxy"""
- self._do_https_request_via_proxy(ssl=True, auth_credentials=None)
+ self._do_https_request_via_proxy(
+ expect_proxy_ssl=True, expected_auth_credentials=None
+ )
@patch.dict(
os.environ,
@@ -446,20 +461,22 @@ class MatrixFederationAgentTests(TestCase):
)
def test_https_request_via_https_proxy_with_auth(self):
"""Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
- self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies")
+ self._do_https_request_via_proxy(
+ expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
+ )
def _do_http_request_via_proxy(
self,
- ssl: bool = False,
- auth_credentials: Optional[bytes] = None,
+ expect_proxy_ssl: bool = False,
+ expected_auth_credentials: Optional[bytes] = None,
):
"""Send a http request via an agent and check that it is correctly received at
the proxy. The proxy can use either http or https.
Args:
- ssl: True if we expect the request to connect via https to proxy
- auth_credentials: credentials to authenticate at proxy
+ expect_proxy_ssl: True if we expect the request to connect via https to proxy
+ expected_auth_credentials: credentials to authenticate at proxy
"""
- if ssl:
+ if expect_proxy_ssl:
agent = ProxyAgent(
self.reactor, use_proxy=True, contextFactory=get_test_https_policy()
)
@@ -480,9 +497,9 @@ class MatrixFederationAgentTests(TestCase):
http_server = self._make_connection(
client_factory,
_get_test_protocol_factory(),
- ssl=ssl,
- tls_sanlist=[b"DNS:proxy.com"] if ssl else None,
- expected_sni=b"proxy.com" if ssl else None,
+ ssl=expect_proxy_ssl,
+ tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
+ expected_sni=b"proxy.com" if expect_proxy_ssl else None,
)
# the FakeTransport is async, so we need to pump the reactor
@@ -498,9 +515,9 @@ class MatrixFederationAgentTests(TestCase):
b"Proxy-Authorization"
)
- if auth_credentials is not None:
+ if expected_auth_credentials is not None:
# Compute the correct header value for Proxy-Authorization
- encoded_credentials = base64.b64encode(auth_credentials)
+ encoded_credentials = base64.b64encode(expected_auth_credentials)
expected_header_value = b"Basic " + encoded_credentials
# Validate the header's value
@@ -523,14 +540,14 @@ class MatrixFederationAgentTests(TestCase):
def _do_https_request_via_proxy(
self,
- ssl: bool = False,
- auth_credentials: Optional[bytes] = None,
+ expect_proxy_ssl: bool = False,
+ expected_auth_credentials: Optional[bytes] = None,
):
"""Send a https request via an agent and check that it is correctly received at
the proxy and client. The proxy can use either http or https.
Args:
- ssl: True if we expect the request to connect via https to proxy
- auth_credentials: credentials to authenticate at proxy
+ expect_proxy_ssl: True if we expect the request to connect via https to proxy
+ expected_auth_credentials: credentials to authenticate at proxy
"""
agent = ProxyAgent(
self.reactor,
@@ -552,9 +569,9 @@ class MatrixFederationAgentTests(TestCase):
proxy_server = self._make_connection(
client_factory,
_get_test_protocol_factory(),
- ssl=ssl,
- tls_sanlist=[b"DNS:proxy.com"] if ssl else None,
- expected_sni=b"proxy.com" if ssl else None,
+ ssl=expect_proxy_ssl,
+ tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
+ expected_sni=b"proxy.com" if expect_proxy_ssl else None,
)
assert isinstance(proxy_server, HTTPChannel)
@@ -570,9 +587,9 @@ class MatrixFederationAgentTests(TestCase):
b"Proxy-Authorization"
)
- if auth_credentials is not None:
+ if expected_auth_credentials is not None:
# Compute the correct header value for Proxy-Authorization
- encoded_credentials = base64.b64encode(auth_credentials)
+ encoded_credentials = base64.b64encode(expected_auth_credentials)
expected_header_value = b"Basic " + encoded_credentials
# Validate the header's value
@@ -606,7 +623,7 @@ class MatrixFederationAgentTests(TestCase):
# Protocol to implement the proxy, which starts out by forwarding to an
# HTTPChannel (to implement the CONNECT command) and can then be switched
# into a mode where it forwards its traffic to another Protocol.)
- if ssl:
+ if expect_proxy_ssl:
assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol)
proxy_server_transport.wrappedProtocol = server_ssl_protocol
else:
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 81d9e2f484..7dd519cd44 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -20,7 +20,7 @@ from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.rest import admin
-from synapse.rest.client.v1 import login, presence, room
+from synapse.rest.client import login, presence, room
from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
@@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
+ def test_get_userinfo_by_id(self):
+ user_id = self.register_user("alice", "1234")
+ found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ self.assertEqual(found_user.user_id.to_string(), user_id)
+ self.assertIdentical(found_user.is_admin, False)
+
+ def test_get_userinfo_by_id__no_user_found(self):
+ found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
+ self.assertIsNone(found_user)
+
def test_sending_events_into_room(self):
"""Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e04bc5c9a6..fa8018e5a7 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import email.message
import os
+from typing import Dict, List, Sequence, Tuple
import attr
import pkg_resources
@@ -21,7 +22,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests.unittest import HomeserverTestCase
@@ -45,14 +46,6 @@ class EmailPusherTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- # List[Tuple[Deferred, args, kwargs]]
- self.email_attempts = []
-
- def sendmail(*args, **kwargs):
- d = Deferred()
- self.email_attempts.append((d, args, kwargs))
- return d
-
config = self.default_config()
config["email"] = {
"enable_notifs": True,
@@ -75,7 +68,18 @@ class EmailPusherTests(HomeserverTestCase):
config["public_baseurl"] = "aaa"
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ hs = self.setup_test_homeserver(config=config)
+
+ # List[Tuple[Deferred, args, kwargs]]
+ self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = []
+
+ def sendmail(*args, **kwargs):
+ # This mocks out synapse.reactor.send_email._sendmail.
+ d = Deferred()
+ self.email_attempts.append((d, args, kwargs))
+ return d
+
+ hs.get_send_email_handler()._sendmail = sendmail
return hs
@@ -123,6 +127,8 @@ class EmailPusherTests(HomeserverTestCase):
)
)
+ self.auth_handler = hs.get_auth_handler()
+
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
their email.
@@ -251,6 +257,39 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_room_notifications_include_avatar(self):
+ # Create a room and set its avatar.
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.send_state(
+ room, "m.room.avatar", {"url": "mxc://DUMMY_MEDIA_ID"}, self.access_token
+ )
+
+ # Invite two other uses.
+ for other in self.others:
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=other.id
+ )
+ self.helper.join(room=room, user=other.id, tok=other.token)
+
+ # The other users send some messages.
+ # TODO It seems that two messages are required to trigger an email?
+ self.helper.send(room, body="Alpha", tok=self.others[0].token)
+ self.helper.send(room, body="Beta", tok=self.others[1].token)
+
+ # We should get emailed about those messages
+ args, kwargs = self._check_for_mail()
+
+ # That email should contain the room's avatar
+ msg: bytes = args[5]
+ # Multipart: plain text, base 64 encoded; html, base 64 encoded
+ html = (
+ email.message_from_bytes(msg)
+ .get_payload()[1]
+ .get_payload(decode=True)
+ .decode()
+ )
+ self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html)
+
def test_empty_room(self):
"""All users leaving a room shouldn't cause the pusher to break."""
# Create a simple room with two users
@@ -303,9 +342,95 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about that message
self._check_for_mail()
- def _check_for_mail(self):
- """Check that the user receives an email notification"""
+ def test_no_email_sent_after_removed(self):
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room,
+ src=self.user_id,
+ tok=self.access_token,
+ targ=self.others[0].id,
+ )
+ self.helper.join(
+ room=room,
+ user=self.others[0].id,
+ tok=self.others[0].token,
+ )
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+ # disassociate the user's email address
+ self.get_success(
+ self.auth_handler.delete_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address="a@example.com",
+ )
+ )
+
+ # check that the pusher for that email address has been deleted
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
+ def test_remove_unlinked_pushers_background_job(self):
+ """Checks that all existing pushers associated with unlinked email addresses are removed
+ upon running the remove_deleted_email_pushers background update.
+ """
+ # disassociate the user's email address manually (without deleting the pusher).
+ # This resembles the old behaviour, which the background update below is intended
+ # to clean up.
+ self.get_success(
+ self.hs.get_datastore().user_delete_threepid(
+ self.user_id, "email", "a@example.com"
+ )
+ )
+
+ # Run the "remove_deleted_email_pushers" background job
+ self.get_success(
+ self.hs.get_datastore().db_pool.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "remove_deleted_email_pushers",
+ "progress_json": "{}",
+ "depends_on": None,
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.hs.get_datastore().db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ while not self.get_success(
+ self.hs.get_datastore().db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.hs.get_datastore().db_pool.updates.do_next_background_update(100),
+ by=0.1,
+ )
+
+ # Check that all pushers with unlinked addresses were deleted
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
+ def _check_for_mail(self) -> Tuple[Sequence, Dict]:
+ """
+ Assert that synapse sent off exactly one email notification.
+
+ Returns:
+ args and kwargs passed to synapse.reactor.send_email._sendmail for
+ that notification.
+ """
# Get the stream ordering before it gets sent
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
@@ -328,8 +453,9 @@ class EmailPusherTests(HomeserverTestCase):
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
+ deferred, sendmail_args, sendmail_kwargs = self.email_attempts[0]
# Make the email succeed
- self.email_attempts[0][0].callback(True)
+ deferred.callback(True)
self.pump()
# One email was attempted to be sent
@@ -345,3 +471,4 @@ class EmailPusherTests(HomeserverTestCase):
# Reset the attempts.
self.email_attempts = []
+ return sendmail_args, sendmail_kwargs
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index ffd75b1491..c068d329a9 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,8 +18,7 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import receipts
+from synapse.rest.client import login, receipts, room
from tests.unittest import HomeserverTestCase, override_config
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index db80a0bdbd..b25a06b427 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -150,6 +150,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
"invite",
event.event_id,
event.internal_metadata.stream_ordering,
+ RoomVersions.V1.identifier,
)
],
)
@@ -216,7 +217,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, expected_pos)},
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -305,7 +306,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
expected_pos = PersistedEventPosition(
"master", j2.internal_metadata.stream_ordering
)
- self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
+ self.assertEqual(
+ joined_rooms,
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+ )
event_id = 0
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 666008425a..f198a94887 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -24,7 +24,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamRow,
)
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 1346e0e160..43a16bb141 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from synapse.rest.client.v2_alpha import register
+from synapse.rest.client import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, make_request
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index b9751efdc5..995097d72c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from synapse.rest.client.v2_alpha import register
+from synapse.rest.client import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index a0c710f855..92a5b53e11 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -205,7 +205,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
- federation = self.hs.get_federation_handler()
+ federation = self.hs.get_federation_event_handler()
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index ffa425328f..ac419f0db3 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -22,7 +22,7 @@ from twisted.web.http import HTTPChannel
from twisted.web.server import Request
from synapse.rest import admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client import login
from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 1e4e3821b9..4094a75f36 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from twisted.internet import defer
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests.replication._base import BaseMultiWorkerStreamTestCase
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index f3615af97e..0a6e4795ee 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -16,8 +16,7 @@ from unittest.mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client import login, room, sync
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index a7c6e595b9..bfa638fb4b 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -24,8 +24,7 @@ import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.admin import VersionServlet
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import groups
+from synapse.rest.client import groups, login, room
from tests import unittest
from tests.server import FakeSite, make_request
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 120730b764..a3679be205 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import urllib.parse
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.errors import Codes
-from synapse.rest.client.v1 import login
+from synapse.rest.client import login
from tests import unittest
@@ -45,49 +46,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.other_user_device_id,
)
- def test_no_auth(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get a device of an user without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("PUT", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("DELETE", self.url, b"{}")
+ channel = self.make_request(method, self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_requester_is_no_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
channel = self.make_request(
- "GET",
- self.url,
- access_token=self.other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "PUT",
- self.url,
- access_token=self.other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
+ method,
self.url,
access_token=self.other_user_token,
)
@@ -95,7 +70,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_user_does_not_exist(self, method: str):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
@@ -105,7 +81,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request(
- "GET",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -113,25 +89,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_user_is_not_local(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_user_is_not_local(self, method: str):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
@@ -141,25 +100,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
- channel = self.make_request(
- "DELETE",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -219,12 +160,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
}
- body = json.dumps(update)
channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content=update,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -275,12 +215,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
Tests a normal successful update of display name
"""
# Set new display_name
- body = json.dumps({"display_name": "new displayname"})
channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"display_name": "new displayname"},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -529,12 +468,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
Tests that a remove of a device that does not exist returns 200.
"""
- body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"devices": ["unknown_device1", "unknown_device2"]},
)
# Delete unknown devices returns status 200
@@ -560,12 +498,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
device_ids.append(str(d["device_id"]))
# Delete devices
- body = json.dumps({"devices": device_ids})
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"devices": device_ids},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index f15d1cf6f7..e9ef89731f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -16,8 +16,7 @@ import json
import synapse.rest.admin
from synapse.api.errors import Codes
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import report_event
+from synapse.rest.client import login, report_event, room
from tests import unittest
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 7198fd293f..972d60570c 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -20,7 +20,7 @@ from parameterized import parameterized
import synapse.rest.admin
from synapse.api.errors import Codes
-from synapse.rest.client.v1 import login, profile, room
+from synapse.rest.client import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/registration_tokens"
+
+ def _new_token(self, **kwargs):
+ """Helper function to create a token."""
+ token = kwargs.get(
+ "token",
+ "".join(random.choices(string.ascii_letters, k=8)),
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": kwargs.get("uses_allowed", None),
+ "pending": kwargs.get("pending", 0),
+ "completed": kwargs.get("completed", 0),
+ "expiry_time": kwargs.get("expiry_time", None),
+ },
+ )
+ )
+ return token
+
+ # CREATION
+
+ def test_create_no_auth(self):
+ """Try to create a token without authentication."""
+ channel = self.make_request("POST", self.url + "/new", {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_create_requester_not_admin(self):
+ """Try to create a token while not an admin."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_create_using_defaults(self):
+ """Create a token using all the defaults."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_specifying_fields(self):
+ """Create a token specifying the value of all fields."""
+ data = {
+ "token": "abcd",
+ "uses_allowed": 1,
+ "expiry_time": self.clock.time_msec() + 1000000,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_with_null_value(self):
+ """Create a token specifying unlimited uses and no expiry."""
+ data = {
+ "uses_allowed": None,
+ "expiry_time": None,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_token_too_long(self):
+ """Check token longer than 64 chars is invalid."""
+ data = {"token": "a" * 65}
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_invalid_chars(self):
+ """Check you can't create token with invalid characters."""
+ data = {
+ "token": "abc/def",
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_already_exists(self):
+ """Check you can't create token that already exists."""
+ data = {
+ "token": "abcd",
+ }
+
+ channel1 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+ channel2 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+ self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_unable_to_generate_token(self):
+ """Check right error is raised when server can't generate unique token."""
+ # Create all possible single character tokens
+ tokens = []
+ for c in string.ascii_letters + string.digits + "-_":
+ tokens.append(
+ {
+ "token": c,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ }
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert_many(
+ "registration_tokens",
+ tokens,
+ "create_all_registration_tokens",
+ )
+ )
+
+ # Check creating a single character token fails with a 500 status code
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_create_uses_allowed(self):
+ """Check you can only create a token with good values for uses_allowed."""
+ # Should work with 0 (token is invalid from the start)
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+ # Should fail with negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_expiry_time(self):
+ """Check you can't create a token with an invalid expiry_time."""
+ # Should fail with a time in the past
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() - 10000},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() + 1000000.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_length(self):
+ """Check you can only generate a token with a valid length."""
+ # Should work with 64
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 64},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 64)
+
+ # Should fail with 0
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 8.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with 65
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 65},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # UPDATING
+
+ def test_update_no_auth(self):
+ """Try to update a token without authentication."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_update_requester_not_admin(self):
+ """Try to update a token while not an admin."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_update_non_existent(self):
+ """Try to update a token that doesn't exist."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234",
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_update_uses_allowed(self):
+ """Test updating just uses_allowed."""
+ # Create new token using default values
+ token = self._new_token()
+
+ # Should succeed with 1
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with 0 (makes token invalid)
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should fail with a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_expiry_time(self):
+ """Test updating just expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ # Should succeed with a time in the future
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should fail with a time in the past
+ past_time = self.clock.time_msec() - 10000
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": past_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time + 0.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_both(self):
+ """Test updating both uses_allowed and expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ data = {
+ "uses_allowed": 1,
+ "expiry_time": new_expiry_time,
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+ def test_update_invalid_type(self):
+ """Test using invalid types doesn't work."""
+ # Create new token using default values
+ token = self._new_token()
+
+ data = {
+ "uses_allowed": False,
+ "expiry_time": "1626430124000",
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # DELETING
+
+ def test_delete_no_auth(self):
+ """Try to delete a token without authentication."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_delete_requester_not_admin(self):
+ """Try to delete a token while not an admin."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_non_existent(self):
+ """Try to delete a token that doesn't exist."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_delete(self):
+ """Test deleting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # GETTING ONE
+
+ def test_get_no_auth(self):
+ """Try to get a token without authentication."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_get_requester_not_admin(self):
+ """Try to get a token while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_get_non_existent(self):
+ """Try to get a token that doesn't exist."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_get(self):
+ """Test getting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], token)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ # LISTING
+
+ def test_list_no_auth(self):
+ """Try to list tokens without authentication."""
+ channel = self.make_request("GET", self.url, {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_list_requester_not_admin(self):
+ """Try to list tokens while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_list_all(self):
+ """Test listing all tokens."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+ token_info = channel.json_body["registration_tokens"][0]
+ self.assertEqual(token_info["token"], token)
+ self.assertIsNone(token_info["uses_allowed"])
+ self.assertIsNone(token_info["expiry_time"])
+ self.assertEqual(token_info["pending"], 0)
+ self.assertEqual(token_info["completed"], 0)
+
+ def test_list_invalid_query_parameter(self):
+ """Test with `valid` query parameter not `true` or `false`."""
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=x",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _test_list_query_parameter(self, valid: str):
+ """Helper used to test both valid=true and valid=false."""
+ # Create 2 valid and 2 invalid tokens.
+ now = self.hs.get_clock().time_msec()
+ # Create always valid token
+ valid1 = self._new_token()
+ # Create token that hasn't been used up
+ valid2 = self._new_token(uses_allowed=1)
+ # Create token that has expired
+ invalid1 = self._new_token(expiry_time=now - 10000)
+ # Create token that has been used up but hasn't expired
+ invalid2 = self._new_token(
+ uses_allowed=2,
+ pending=1,
+ completed=1,
+ expiry_time=now + 1000000,
+ )
+
+ if valid == "true":
+ tokens = [valid1, valid2]
+ else:
+ tokens = [invalid1, invalid2]
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=" + valid,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+ token_info_1 = channel.json_body["registration_tokens"][0]
+ token_info_2 = channel.json_body["registration_tokens"][1]
+ self.assertIn(token_info_1["token"], tokens)
+ self.assertIn(token_info_2["token"], tokens)
+
+ def test_list_valid(self):
+ """Test listing just valid tokens."""
+ self._test_list_query_parameter(valid="true")
+
+ def test_list_invalid(self):
+ """Test listing just invalid tokens."""
+ self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 17ec8bfd3b..40e032df7f 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -22,130 +22,13 @@ from parameterized import parameterized_class
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
-from synapse.rest.client.v1 import directory, events, login, room
+from synapse.rest.client import directory, events, login, room
from tests import unittest
"""Tests admin REST events for /rooms paths."""
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/shutdown_room/" + room_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/shutdown_room/" + room_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room."""
-
- url = "rooms/%s/initialSync" % (room_id,)
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
@parameterized_class(
("method", "url_template"),
[
@@ -557,51 +440,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API."""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in PURGE_TABLES:
- count = self.get_success(
- self.store.db_pool.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
-
-
class RoomTestCase(unittest.HomeserverTestCase):
"""Test /room admin API."""
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
new file mode 100644
index 0000000000..fbceba3254
--- /dev/null
+++ b/tests/rest/admin/test_server_notice.py
@@ -0,0 +1,450 @@
+# Copyright 2021 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login, room, sync
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class ServerNoticeTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
+ self.pagination_handler = hs.get_pagination_handler()
+ self.server_notices_manager = self.hs.get_server_notices_manager()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/send_server_notice"
+
+ def test_no_auth(self):
+ """Try to send a server notice without authentication."""
+ channel = self.make_request("POST", self.url)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """If the user is not a server admin, an error is returned."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_does_not_exist(self):
+ """Tests that a lookup for a user that does not exist returns a 404"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": "@unknown_person:test", "content": ""},
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": "@unknown_person:unknown_domain",
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Server notices can only be sent to local users", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_invalid_parameter(self):
+ """If parameters are invalid, an error is returned."""
+
+ # no content, no user
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+ # no content
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # no body
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": ""},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'body' not in content", channel.json_body["error"])
+
+ # no msgtype
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": {"body": ""}},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+
+ def test_server_notice_disabled(self):
+ """Tests that server returns error if server notice is disabled"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(
+ "Server notices are not enabled on this server", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice(self):
+ """
+ Tests that sending two server notices is successfully,
+ the server uses the same room and do not send messages twice.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token)
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # invalidate cache of server notices room_ids
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has no new invites or memberships
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 2)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ self.assertEqual(messages[1]["content"]["body"], "test msg two")
+ self.assertEqual(messages[1]["sender"], "@notices:test")
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_leave_room(self):
+ """
+ Tests that sending a server notices is successfully.
+ The user leaves the room and the second message appears
+ in a new room.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # user leaves the romm
+ self.helper.leave(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id the user gets the message
+ # in old room
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # room has the same id
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_delete_room(self):
+ """
+ Tests that the user get server notice in a new room
+ after the first server notice room was deleted.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # shut down and purge room
+ self.get_success(
+ self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user)
+ )
+ self.get_success(self.pagination_handler.purge_room(first_room_id))
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(first_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id it gives an error
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get message
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # second room has new ID
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ def _check_invite_and_join_status(
+ self, user_id: str, expected_invites: int, expected_memberships: int
+ ) -> RoomsForUser:
+ """Check invite and room membership status of a user.
+
+ Args
+ user_id: user to check
+ expected_invites: number of expected invites of this user
+ expected_memberships: number of expected room memberships of this user
+ Returns
+ room_ids from the rooms that the user is invited
+ """
+
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(user_id)
+ )
+ self.assertEqual(expected_invites, len(invited_rooms))
+
+ room_ids = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(expected_memberships, len(room_ids))
+
+ return invited_rooms
+
+ def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]:
+ """
+ Do a sync and get messages of a room.
+
+ Args
+ room_id: room that contains the messages
+ token: access token of user
+
+ Returns
+ list of messages contained in the room
+ """
+ channel = self.make_request(
+ "GET", "/_matrix/client/r0/sync", access_token=token
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Get the messages
+ room = channel.json_body["rooms"]["join"][room_id]
+ messages = [
+ x for x in room["timeline"]["events"] if x["type"] == "m.room.message"
+ ]
+ return messages
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 79cac4266b..5cd82209c4 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional
import synapse.rest.admin
from synapse.api.errors import Codes
-from synapse.rest.client.v1 import login
+from synapse.rest.client import login
from tests import unittest
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 42f50c0921..ee204c404b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -15,17 +15,20 @@
import hashlib
import hmac
import json
+import os
import urllib.parse
from binascii import unhexlify
from typing import List, Optional
from unittest.mock import Mock, patch
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
-from synapse.rest.client.v1 import login, logout, profile, room
-from synapse.rest.client.v2_alpha import devices, sync
+from synapse.rest.client import devices, login, logout, profile, room, sync
+from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.types import JsonDict, UserID
from tests import unittest
@@ -72,7 +75,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"]
)
@@ -104,7 +107,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({"nonce": nonce})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds
@@ -112,7 +115,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self):
@@ -166,7 +169,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self):
@@ -191,13 +194,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self):
@@ -219,7 +222,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
#
@@ -230,28 +233,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({"nonce": nonce()})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
#
@@ -262,28 +265,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = json.dumps({"nonce": nonce(), "username": "a"})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
#
@@ -301,7 +304,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self):
@@ -322,11 +325,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname")
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
@@ -348,11 +351,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname")
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
@@ -374,7 +377,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
@@ -399,11 +402,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname")
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config(
@@ -449,7 +452,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request("POST", self.url, body.encode("utf8"))
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -638,7 +641,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order
@@ -1085,7 +1088,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1236,56 +1239,114 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
- def test_get_user(self):
+ def test_invalid_parameter(self):
"""
- Test a simple get of a user.
+ If parameters are invalid, an error is returned.
"""
+
+ # admin not bool
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"admin": "not_bool"},
)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual("User", channel.json_body["displayname"])
- self._check_fields(channel.json_body)
+ # deactivated not bool
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": "not_bool"},
+ )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
- def test_get_user_with_sso(self):
- """
- Test get a user with SSO details.
- """
- self.get_success(
- self.store.record_user_external_id(
- "auth_provider1", "external_id1", self.other_user
- )
+ # password not str
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"password": True},
)
- self.get_success(
- self.store.record_user_external_id(
- "auth_provider2", "external_id2", self.other_user
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # password not length
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"password": "x" * 513},
)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ # user_type not valid
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"user_type": "new type"},
)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(
- "external_id1", channel.json_body["external_ids"][0]["external_id"]
+ # external_ids not valid
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
+ },
)
- self.assertEqual(
- "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"external_ids": {"external_id": "id"}},
)
- self.assertEqual(
- "external_id2", channel.json_body["external_ids"][1]["external_id"]
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # threepids not valid
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"threepids": {"medium": "email", "wrong_address": "id"}},
)
- self.assertEqual(
- "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"]
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"threepids": {"address": "value"}},
)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_get_user(self):
+ """
+ Test a simple get of a user.
+ """
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body)
def test_create_server_admin(self):
@@ -1349,6 +1410,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"admin": False,
"displayname": "Bob's name",
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "external_ids": [
+ {
+ "external_id": "external_id1",
+ "auth_provider": "auth_provider1",
+ },
+ ],
"avatar_url": "mxc://fibble/wibble",
}
@@ -1364,6 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(1, len(channel.json_body["threepids"]))
+ self.assertEqual(
+ "external_id1", channel.json_body["external_ids"][0]["external_id"]
+ )
+ self.assertEqual(
+ "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
+ )
+ self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1603,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Test setting threepid for an other user.
"""
- # Delete old and add new threepid to user
+ # Add two threepids to user
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob1@bob.bob"},
+ {"medium": "email", "address": "bob2@bob.bob"},
+ ],
+ },
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
+ # result does not always have the same sort order, therefore it becomes sorted
+ sorted_result = sorted(
+ channel.json_body["threepids"], key=lambda k: k["address"]
+ )
+ self.assertEqual("email", sorted_result[0]["medium"])
+ self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+ self.assertEqual("email", sorted_result[1]["medium"])
+ self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Set a new and remove a threepid
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob2@bob.bob"},
+ {"medium": "email", "address": "bob3@bob.bob"},
+ ],
+ },
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
# Get user
channel = self.make_request(
@@ -1625,8 +1735,122 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Remove threepids
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"threepids": []},
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self._check_fields(channel.json_body)
+
+ def test_set_external_id(self):
+ """
+ Test setting external id for an other user.
+ """
+
+ # Add two external_ids
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "external_ids": [
+ {
+ "external_id": "external_id1",
+ "auth_provider": "auth_provider1",
+ },
+ {
+ "external_id": "external_id2",
+ "auth_provider": "auth_provider2",
+ },
+ ]
+ },
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["external_ids"]))
+ # result does not always have the same sort order, therefore it becomes sorted
+ self.assertEqual(
+ sorted(channel.json_body["external_ids"], key=lambda k: k["auth_provider"]),
+ [
+ {"auth_provider": "auth_provider1", "external_id": "external_id1"},
+ {"auth_provider": "auth_provider2", "external_id": "external_id2"},
+ ],
+ )
+ self._check_fields(channel.json_body)
+
+ # Set a new and remove an external_id
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "external_ids": [
+ {
+ "external_id": "external_id2",
+ "auth_provider": "auth_provider2",
+ },
+ {
+ "external_id": "external_id3",
+ "auth_provider": "auth_provider3",
+ },
+ ]
+ },
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["external_ids"]))
+ self.assertEqual(
+ channel.json_body["external_ids"],
+ [
+ {"auth_provider": "auth_provider2", "external_id": "external_id2"},
+ {"auth_provider": "auth_provider3", "external_id": "external_id3"},
+ ],
+ )
+ self._check_fields(channel.json_body)
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["external_ids"]))
+ self.assertEqual(
+ channel.json_body["external_ids"],
+ [
+ {"auth_provider": "auth_provider2", "external_id": "external_id2"},
+ {"auth_provider": "auth_provider3", "external_id": "external_id3"},
+ ],
+ )
+ self._check_fields(channel.json_body)
+
+ # Remove external_ids
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"external_ids": []},
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(0, len(channel.json_body["external_ids"]))
def test_deactivate_user(self):
"""
@@ -2180,7 +2404,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self):
"""
@@ -2249,6 +2473,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource()
+ self.filepaths = MediaFilePaths(hs.config.media_store_path)
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -2258,37 +2483,34 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self):
- """
- Try to list media of an user without authentication.
- """
- channel = self.make_request("GET", self.url, b"{}")
+ @parameterized.expand(["GET", "DELETE"])
+ def test_no_auth(self, method: str):
+ """Try to list media of an user without authentication."""
+ channel = self.make_request(method, self.url, {})
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
- """
- If the user is not a server admin, an error is returned.
- """
+ @parameterized.expand(["GET", "DELETE"])
+ def test_requester_is_no_admin(self, method: str):
+ """If the user is not a server admin, an error is returned."""
other_user_token = self.login("user", "pass")
channel = self.make_request(
- "GET",
+ method,
self.url,
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
- """
- Tests that a lookup for a user that does not exist returns a 404
- """
+ @parameterized.expand(["GET", "DELETE"])
+ def test_user_does_not_exist(self, method: str):
+ """Tests that a lookup for a user that does not exist returns a 404"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
- "GET",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -2296,25 +2518,22 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_user_is_not_local(self):
- """
- Tests that a lookup for a user that is not a local returns a 400
- """
+ @parameterized.expand(["GET", "DELETE"])
+ def test_user_is_not_local(self, method: str):
+ """Tests that a lookup for a user that is not a local returns a 400"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request(
- "GET",
+ method,
url,
access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ self.assertEqual("Can only look up local users", channel.json_body["error"])
- def test_limit(self):
- """
- Testing list of media with limit
- """
+ def test_limit_GET(self):
+ """Testing list of media with limit"""
number_media = 20
other_user_tok = self.login("user", "pass")
@@ -2326,16 +2545,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
self._check_fields(channel.json_body["media"])
- def test_from(self):
- """
- Testing list of media with a defined starting point (from)
- """
+ def test_limit_DELETE(self):
+ """Testing delete of media with limit"""
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media_for_user(other_user_tok, number_media)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "?limit=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], 5)
+ self.assertEqual(len(channel.json_body["deleted_media"]), 5)
+
+ def test_from_GET(self):
+ """Testing list of media with a defined starting point (from)"""
number_media = 20
other_user_tok = self.login("user", "pass")
@@ -2347,16 +2581,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"])
- def test_limit_and_from(self):
- """
- Testing list of media with a defined starting point and limit
- """
+ def test_from_DELETE(self):
+ """Testing delete of media with a defined starting point (from)"""
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media_for_user(other_user_tok, number_media)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "?from=5",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], 15)
+ self.assertEqual(len(channel.json_body["deleted_media"]), 15)
+
+ def test_limit_and_from_GET(self):
+ """Testing list of media with a defined starting point and limit"""
number_media = 20
other_user_tok = self.login("user", "pass")
@@ -2368,59 +2617,78 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10)
self._check_fields(channel.json_body["media"])
- def test_invalid_parameter(self):
- """
- If parameters are invalid, an error is returned.
- """
+ def test_limit_and_from_DELETE(self):
+ """Testing delete of media with a defined starting point and limit"""
+
+ number_media = 20
+ other_user_tok = self.login("user", "pass")
+ self._create_media_for_user(other_user_tok, number_media)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "?from=5&limit=10",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["deleted_media"]), 10)
+
+ @parameterized.expand(["GET", "DELETE"])
+ def test_invalid_parameter(self, method: str):
+ """If parameters are invalid, an error is returned."""
# unkown order_by
channel = self.make_request(
- "GET",
+ method,
self.url + "?order_by=bar",
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid search order
channel = self.make_request(
- "GET",
+ method,
self.url + "?dir=bar",
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# negative limit
channel = self.make_request(
- "GET",
+ method,
self.url + "?limit=-5",
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
channel = self.make_request(
- "GET",
+ method,
self.url + "?from=-5",
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self):
"""
Testing that `next_token` appears at the right place
+
+ For deleting media `next_token` is not useful, because
+ after deleting media the media has a new order.
"""
number_media = 20
@@ -2435,7 +2703,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -2448,7 +2716,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -2461,7 +2729,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -2475,12 +2743,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body)
- def test_user_has_no_media(self):
+ def test_user_has_no_media_GET(self):
"""
Tests that a normal lookup for media is successfully
if user has no media created
@@ -2496,11 +2764,24 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"]))
- def test_get_media(self):
+ def test_user_has_no_media_DELETE(self):
"""
- Tests that a normal lookup for media is successfully
+ Tests that a delete is successful if user has no media
"""
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["total"])
+ self.assertEqual(0, len(channel.json_body["deleted_media"]))
+
+ def test_get_media(self):
+ """Tests that a normal lookup for media is successful"""
+
number_media = 5
other_user_tok = self.login("user", "pass")
self._create_media_for_user(other_user_tok, number_media)
@@ -2517,6 +2798,35 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["media"])
+ def test_delete_media(self):
+ """Tests that a normal delete of media is successful"""
+
+ number_media = 5
+ other_user_tok = self.login("user", "pass")
+ media_ids = self._create_media_for_user(other_user_tok, number_media)
+
+ # Test if the file exists
+ local_paths = []
+ for media_id in media_ids:
+ local_path = self.filepaths.local_media_filepath(media_id)
+ self.assertTrue(os.path.exists(local_path))
+ local_paths.append(local_path)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(number_media, channel.json_body["total"])
+ self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
+ self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
+
+ # Test if the file is deleted
+ for local_path in local_paths:
+ self.assertFalse(os.path.exists(local_path))
+
def test_order_by(self):
"""
Testing order list with parameter `order_by`
@@ -2622,13 +2932,16 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
[media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
)
- def _create_media_for_user(self, user_token: str, number_media: int):
+ def _create_media_for_user(self, user_token: str, number_media: int) -> List[str]:
"""
Create a number of media for a specific user
Args:
user_token: Access token of the user
number_media: Number of media to be created for the user
+ Returns:
+ List of created media ID
"""
+ media_ids = []
for _ in range(number_media):
# file size is 67 Byte
image_data = unhexlify(
@@ -2637,7 +2950,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
b"0a2db40000000049454e44ae426082"
)
- self._create_media_and_access(user_token, image_data)
+ media_ids.append(self._create_media_and_access(user_token, image_data))
+
+ return media_ids
def _create_media_and_access(
self,
@@ -2680,7 +2995,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
200,
channel.code,
msg=(
- "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+ f"Expected to receive a 200 on accessing media: {server_and_media_id}"
),
)
@@ -2718,12 +3033,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url = self.url + "?"
if order_by is not None:
- url += "order_by=%s&" % (order_by,)
+ url += f"order_by={order_by}&"
if dir is not None and dir in ("b", "f"):
- url += "dir=%s" % (dir,)
+ url += f"dir={dir}"
channel = self.make_request(
"GET",
- url.encode("ascii"),
+ url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -2762,7 +3077,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
return channel.json_body["access_token"]
def test_no_auth(self):
@@ -2803,7 +3118,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
@@ -2815,11 +3130,11 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
@@ -2829,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_user_logout_all(self):
"""Tests that the target user calling `/logout/all` does *not* expire
@@ -2840,17 +3155,17 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the real user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't
channel = self.make_request(
@@ -2867,13 +3182,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the admin user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
@@ -2883,7 +3198,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
@unittest.override_config(
{
@@ -3243,7 +3558,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
+ self.assertEqual("Can only look up local users", channel.json_body["error"])
channel = self.make_request(
"POST",
@@ -3279,7 +3594,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative
@@ -3290,7 +3605,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string
@@ -3301,7 +3616,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative
@@ -3312,7 +3627,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self):
@@ -3337,7 +3652,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"])
@@ -3351,7 +3666,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3362,7 +3677,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"])
@@ -3373,7 +3688,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3383,7 +3698,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3393,7 +3708,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3403,6 +3718,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
new file mode 100644
index 0000000000..4e1c49c28b
--- /dev/null
+++ b/tests/rest/admin/test_username_available.py
@@ -0,0 +1,62 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import synapse.rest.admin
+from synapse.api.errors import Codes, SynapseError
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class UsernameAvailableTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+ url = "/_synapse/admin/v1/username_available"
+
+ def prepare(self, reactor, clock, hs):
+ self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ async def check_username(username):
+ if username == "allowed":
+ return True
+ raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
+
+ handler = self.hs.get_registration_handler()
+ handler.check_username = check_username
+
+ def test_username_available(self):
+ """
+ The endpoint should return a 200 response if the username does not exist
+ """
+
+ url = "%s?username=%s" % (self.url, "allowed")
+ channel = self.make_request("GET", url, None, self.admin_user_tok)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertTrue(channel.json_body["available"])
+
+ def test_username_unavailable(self):
+ """
+ The endpoint should return a 200 response if the username does not exist
+ """
+
+ url = "%s?username=%s" % (self.url, "disallowed")
+ channel = self.make_request("GET", url, None, self.admin_user_tok)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
+ self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/test_account.py
index 317a2287e3..b946fca8b3 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -25,8 +25,7 @@ import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes, HttpResponseException
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import account, register
+from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
@@ -47,12 +46,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
config = self.default_config()
# Email config.
- self.email_attempts = []
-
- async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
- self.email_attempts.append(msg)
- return
-
config["email"] = {
"enable_notifs": False,
"template_dir": os.path.abspath(
@@ -67,7 +60,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
}
config["public_baseurl"] = "https://example.com"
- hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ hs = self.setup_test_homeserver(config=config)
+
+ async def sendmail(
+ reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
+ ):
+ self.email_attempts.append(msg)
+
+ self.email_attempts = []
+ hs.get_send_email_handler()._sendmail = sendmail
+
return hs
def prepare(self, reactor, clock, hs):
@@ -511,11 +513,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
config = self.default_config()
# Email config.
- self.email_attempts = []
-
- async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs):
- self.email_attempts.append(msg)
-
config["email"] = {
"enable_notifs": False,
"template_dir": os.path.abspath(
@@ -530,7 +527,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
}
config["public_baseurl"] = "https://example.com"
- self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ self.hs = self.setup_test_homeserver(config=config)
+
+ async def sendmail(
+ reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
+ ):
+ self.email_attempts.append(msg)
+
+ self.email_attempts = []
+ self.hs.get_send_email_handler()._sendmail = sendmail
+
return self.hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/test_auth.py
index 6b90f838b6..e2fcbdc63a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -19,14 +19,13 @@ from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.rest.client.v1 import login
-from synapse.rest.client.v2_alpha import account, auth, devices, register
+from synapse.rest.client import account, auth, devices, login, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import JsonDict, UserID
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/test_capabilities.py
index f80f48a455..422361b62a 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -13,8 +13,7 @@
# limitations under the License.
import synapse.rest.admin
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.rest.client.v1 import login
-from synapse.rest.client.v2_alpha import capabilities
+from synapse.rest.client import capabilities, login
from tests import unittest
from tests.unittest import override_config
@@ -31,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
- self.store = hs.get_datastore()
self.config = hs.config
self.auth_handler = hs.get_auth_handler()
return hs
+ def prepare(self, reactor, clock, hs):
+ self.localpart = "user"
+ self.password = "pass"
+ self.user = self.register_user(self.localpart, self.password)
+
def test_check_auth_required(self):
channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
def test_get_room_version_capabilities(self):
- self.register_user("user", "pass")
- access_token = self.login("user", "pass")
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -58,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
)
def test_get_change_password_capabilities_password_login(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
- access_token = self.login(user, password)
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -71,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -88,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -103,13 +96,86 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"])
- def test_get_does_not_include_msc3244_fields_by_default(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
+ def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+ """Test that per default msc3283 is disabled server returns `m.change_password`."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+ self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+ self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+ @override_config({"experimental_features": {"msc3283_enabled": True}})
+ def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+ """Test if msc3283 is enabled server returns capabilities."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_displayname": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_displayname_capabilities_displayname_disabled(self):
+ """Test if set displayname is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_avatar_url": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+ """Test if set avatar_url is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+ @override_config(
+ {
+ "enable_3pid_changes": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_change_3pid_capabilities_3pid_disabled(self):
+ """Test if change 3pid is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+ @override_config({"experimental_features": {"msc3244_enabled": False}})
+ def test_get_does_not_include_msc3244_fields_when_disabled(self):
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -121,14 +187,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
"org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
)
- @override_config({"experimental_features": {"msc3244_enabled": True}})
def test_get_does_include_msc3244_fields_when_enabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 5cc62a910a..65c58ce70a 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -16,7 +16,7 @@ import os
import synapse.rest.admin
from synapse.api.urls import ConsentURIBuilder
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.rest.consent import consent_resource
from tests import unittest
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/test_directory.py
index 8ed470490b..d2181ea907 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -15,7 +15,7 @@
import json
from synapse.rest import admin
-from synapse.rest.client.v1 import directory, login, room
+from synapse.rest.client import directory, login, room
from synapse.types import RoomAlias
from synapse.util.stringutils import random_string
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index eec0fc01f9..3d7aa8ec86 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -13,7 +13,7 @@
# limitations under the License.
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest import admin
-from synapse.rest.client.v1 import room
+from synapse.rest.client import room
from tests import unittest
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/test_events.py
index 2789d51546..a90294003e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -17,7 +17,7 @@
from unittest.mock import Mock
import synapse.rest.admin
-from synapse.rest.client.v1 import events, login, room
+from synapse.rest.client import events, login, room
from tests import unittest
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/test_filter.py
index c7e47725b7..475c6bed3d 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
from synapse.api.errors import Codes
-from synapse.rest.client.v2_alpha import filter
+from synapse.rest.client import filter
from tests import unittest
diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py
new file mode 100644
index 0000000000..ad0425ae65
--- /dev/null
+++ b/tests/rest/client/test_groups.py
@@ -0,0 +1,56 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest.client import groups, room
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class GroupsTestCase(unittest.HomeserverTestCase):
+ user_id = "@alice:test"
+ room_creator_user_id = "@bob:test"
+
+ servlets = [room.register_servlets, groups.register_servlets]
+
+ @override_config({"enable_group_creation": True})
+ def test_rooms_limited_by_visibility(self):
+ group_id = "+spqr:test"
+
+ # Alice creates a group
+ channel = self.make_request("POST", "/create_group", {"localpart": "spqr"})
+ self.assertEquals(channel.code, 200, msg=channel.text_body)
+ self.assertEquals(channel.json_body, {"group_id": group_id})
+
+ # Bob creates a private room
+ room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False)
+ self.helper.auth_user_id = self.room_creator_user_id
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": "bob's secret room"}, tok=None
+ )
+ self.helper.auth_user_id = self.user_id
+
+ # Alice adds the room to her group.
+ channel = self.make_request(
+ "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {}
+ )
+ self.assertEquals(channel.code, 200, msg=channel.text_body)
+ self.assertEquals(channel.json_body, {})
+
+ # Alice now tries to retrieve the room list of the space.
+ channel = self.make_request("GET", f"/groups/{group_id}/rooms")
+ self.assertEquals(channel.code, 200, msg=channel.text_body)
+ self.assertEquals(
+ channel.json_body, {"chunk": [], "total_room_count_estimate": 0}
+ )
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 478296ba0e..ca2e8ff8ef 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,7 +15,7 @@
import json
import synapse.rest.admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests import unittest
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
new file mode 100644
index 0000000000..d7fa635eae
--- /dev/null
+++ b/tests/rest/client/test_keys.py
@@ -0,0 +1,91 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ keys.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def test_rejects_device_id_ice_key_outside_of_list(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: "device_id1",
+ },
+ },
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_rejects_device_key_given_as_map_to_bool(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: {
+ "device_id1": True,
+ },
+ },
+ },
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_requires_device_key(self):
+ """`device_keys` is required. We should complain if it's missing."""
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {},
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/test_login.py
index 7eba69642a..5b2243fe52 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -24,16 +24,15 @@ from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import devices, register
-from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import create_requester
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
-from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 6f07ff6cbb..3cf5871899 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
@@ -17,8 +17,7 @@ import json
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.rest import admin
-from synapse.rest.client.v1 import login
-from synapse.rest.client.v2_alpha import account, password_policy, register
+from synapse.rest.client import account, login, password_policy, register
from tests import unittest
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index ba5ad47df5..c0de4c93a8 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.errors import Codes
+from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client import login, room, sync
from tests.unittest import HomeserverTestCase
@@ -204,3 +205,79 @@ class PowerLevelsTestCase(HomeserverTestCase):
tok=self.admin_access_token,
expect_code=200, # expect success
)
+
+ def test_cannot_set_string_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL "0"
+ room_power_levels["users"].update({self.user_user_id: "0"})
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_large_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL above the max safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MAX_INT + 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_small_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL below the minimum safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MIN_INT - 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/test_presence.py
index 597e4c67de..1d152352d1 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from twisted.internet import defer
from synapse.handlers.presence import PresenceHandler
-from synapse.rest.client.v1 import presence
+from synapse.rest.client import presence
from synapse.types import UserID
from tests import unittest
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/test_profile.py
index 165ad33fb7..2860579c2e 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -14,7 +14,7 @@
"""Tests REST events for /profile paths."""
from synapse.rest import admin
-from synapse.rest.client.v1 import login, profile, room
+from synapse.rest.client import login, profile, room
from tests import unittest
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index d077616082..d0ce91ccd9 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
@@ -13,7 +13,7 @@
# limitations under the License.
import synapse
from synapse.api.errors import Codes
-from synapse.rest.client.v1 import login, push_rule, room
+from synapse.rest.client import login, push_rule, room
from tests.unittest import HomeserverTestCase
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index dfd85221d0..433d715f69 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -13,8 +13,7 @@
# limitations under the License.
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client import login, room, sync
from tests.unittest import HomeserverTestCase
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/test_register.py
index 1cad5f00eb..9f3ab2c985 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -23,8 +23,8 @@ import synapse.rest.admin
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
-from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import account, account_validity, register, sync
+from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
from tests import unittest
from tests.unittest import override_config
@@ -205,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_requires_token(self):
+ username = "kermit"
+ device_id = "frogfone"
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params = {
+ "username": username,
+ "password": "monkey",
+ "device_id": device_id,
+ }
+
+ # Request without auth to get flows and session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+ # Synapse adds a dummy stage to differentiate flows where otherwise one
+ # flow would be a subset of another flow.
+ self.assertCountEqual(
+ [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+ (f["stages"] for f in flows),
+ )
+ session = channel.json_body["session"]
+
+ # Do the registration token stage and check it has completed
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ # Do the m.login.dummy stage and check registration was successful
+ params["auth"] = {
+ "type": LoginType.DUMMY,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ det_data = {
+ "user_id": f"@{username}:{self.hs.hostname}",
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, channel.json_body)
+
+ # Check the `completed` counter has been incremented and pending is 0
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["completed"], 1)
+ self.assertEquals(res["pending"], 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_invalid(self):
+ params = {
+ "username": "kermit",
+ "password": "monkey",
+ }
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Test with token param missing (invalid)
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with non-string (invalid)
+ params["auth"]["token"] = 1234
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with unknown token (invalid)
+ params["auth"]["token"] = "1234"
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_limit_uses(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ # Create token that can be used once
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ # Do 2 requests without auth to get two session IDs
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with session1 and check `pending` is 1
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Repeat request to make sure pending isn't increased again
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 1)
+
+ # Check auth fails when using token with session2
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Check pending=0 and completed=1
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["pending"], 0)
+ self.assertEquals(res["completed"], 1)
+
+ # Check auth still fails when using token with session2
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_expiry(self):
+ token = "abcd"
+ now = self.hs.get_clock().time_msec()
+ store = self.hs.get_datastore()
+ # Create token that expired yesterday
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": now - 24 * 60 * 60 * 1000,
+ },
+ )
+ )
+ params = {"username": "kermit", "password": "monkey"}
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Check authentication fails with expired token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Update token so it expires tomorrow
+ self.get_success(
+ store.db_pool.simple_update_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+ )
+ )
+
+ # Check authentication succeeds
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry(self):
+ """Test `pending` is decremented when an uncompleted session expires."""
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do 2 requests without auth to get two session IDs
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with both sessions
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params2))
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ # Check `result` of registration token stage for session1 is `True`
+ result1 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session1,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertTrue(db_to_json(result1))
+
+ # Check `result` for session2 is the token used
+ result2 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session2,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertEquals(db_to_json(result2), token)
+
+ # Delete both sessions (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
+ # Check pending is now 0
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry_deleted_token(self):
+ """Test session expiry doesn't break when the token is deleted.
+
+ 1. Start but don't complete UIA with a registration token
+ 2. Delete the token from the database
+ 3. Expire the session
+ """
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do request without auth to get a session ID
+ params = {"username": "kermit", "password": "monkey"}
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Use token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params))
+
+ # Delete token
+ self.get_success(
+ store.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ )
+ )
+
+ # Delete session (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -509,10 +874,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
}
# Email config.
- self.email_attempts = []
-
- async def sendmail(*args, **kwargs):
- self.email_attempts.append((args, kwargs))
config["email"] = {
"enable_notifs": True,
@@ -532,7 +893,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
}
config["public_baseurl"] = "aaa"
- self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
+ self.hs = self.setup_test_homeserver(config=config)
+
+ async def sendmail(*args, **kwargs):
+ self.email_attempts.append((args, kwargs))
+
+ self.email_attempts = []
+ self.hs.get_send_email_handler()._sendmail = sendmail
self.store = self.hs.get_datastore()
@@ -743,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [register.register_servlets]
+ url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+ def default_config(self):
+ config = super().default_config()
+ config["registration_requires_token"] = True
+ return config
+
+ def test_GET_token_valid(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], True)
+
+ def test_GET_token_invalid(self):
+ token = "1234"
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], False)
+
+ @override_config(
+ {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+ )
+ def test_GET_ratelimiting(self):
+ token = "1234"
+
+ for i in range(0, 6):
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+
+ if i == 5:
+ self.assertEquals(channel.result["code"], b"429", channel.result)
+ retry_after_ms = int(channel.json_body["retry_after_ms"])
+ else:
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/test_relations.py
index 2e2f94742e..02b5e9a8d0 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -19,8 +19,7 @@ from typing import Optional
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import register, relations
+from synapse.rest.client import login, register, relations, room
from tests import unittest
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/test_report_event.py
index a76a6fef1e..ee6b0b9ebf 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -15,8 +15,7 @@
import json
import synapse.rest.admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import report_event
+from synapse.rest.client import login, report_event, room
from tests import unittest
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index e1a6e73e17..b58452195a 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -15,7 +15,7 @@ from unittest.mock import Mock
from synapse.api.constants import EventTypes
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.visibility import filter_events_for_client
from tests import unittest
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/test_rooms.py
index 3df070c936..50100a5ae4 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -19,15 +19,17 @@
import json
from typing import Iterable
-from unittest.mock import Mock
+from unittest.mock import Mock, call
from urllib import parse as urlparse
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.errors import HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
-from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import account
+from synapse.rest.client import account, directory, login, profile, room, sync
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
@@ -379,6 +381,8 @@ class RoomPermissionsTestCase(RoomBase):
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
+ servlets = RoomBase.servlets + [sync.register_servlets]
+
user_id = "@sid1:red"
def test_get_member_list(self):
@@ -395,6 +399,86 @@ class RoomsMemberListTestCase(RoomBase):
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEquals(403, channel.code, msg=channel.result["body"])
+ def test_get_member_list_no_permission_with_at_token(self):
+ """
+ Tests that a stranger to the room cannot get the member list
+ (in the case that they use an at token).
+ """
+ room_id = self.helper.create_room_as("@someone.else:red")
+
+ # first sync to get an at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEquals(200, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ # check that permission is denied for @sid1:red to get the
+ # memberships of @someone.else:red's room.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{room_id}/members?at={sync_token}",
+ )
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
+
+ def test_get_member_list_no_permission_former_member(self):
+ """
+ Tests that a former member of the room can not get the member list.
+ """
+ # create a room, invite the user and the user joins
+ room_id = self.helper.create_room_as("@alice:red")
+ self.helper.invite(room_id, "@alice:red", self.user_id)
+ self.helper.join(room_id, self.user_id)
+
+ # check that the user can see the member list to start with
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
+
+ # ban the user
+ self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban")
+
+ # check the user can no longer see the member list
+ channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
+
+ def test_get_member_list_no_permission_former_member_with_at_token(self):
+ """
+ Tests that a former member of the room can not get the member list
+ (in the case that they use an at token).
+ """
+ # create a room, invite the user and the user joins
+ room_id = self.helper.create_room_as("@alice:red")
+ self.helper.invite(room_id, "@alice:red", self.user_id)
+ self.helper.join(room_id, self.user_id)
+
+ # sync to get an at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEquals(200, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ # check that the user can see the member list to start with
+ channel = self.make_request(
+ "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
+ )
+ self.assertEquals(200, channel.code, msg=channel.result["body"])
+
+ # ban the user (Note: the user is actually allowed to see this event and
+ # state so that they know they're banned!)
+ self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban")
+
+ # invite a third user and let them join
+ self.helper.invite(room_id, "@alice:red", "@bob:red")
+ self.helper.join(room_id, "@bob:red")
+
+ # now, with the original user, sync again to get a new at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEquals(200, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ # check the user can no longer see the updated member list
+ channel = self.make_request(
+ "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
+ )
+ self.assertEquals(403, channel.code, msg=channel.result["body"])
+
def test_get_member_list_mixed_memberships(self):
room_creator = "@some_other_guy:red"
room_id = self.helper.create_room_as(room_creator)
@@ -1124,6 +1208,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
+class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
+ """Test that we correctly fallback to local filtering if a remote server
+ doesn't support search.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(federation_client=Mock())
+
+ def prepare(self, reactor, clock, hs):
+ self.register_user("user", "pass")
+ self.token = self.login("user", "pass")
+
+ self.federation_client = hs.get_federation_client()
+
+ def test_simple(self):
+ "Simple test for searching rooms over federation"
+ self.federation_client.get_public_rooms.side_effect = (
+ lambda *a, **k: defer.succeed({})
+ )
+
+ search_filter = {"generic_search_term": "foobar"}
+
+ channel = self.make_request(
+ "POST",
+ b"/_matrix/client/r0/publicRooms?server=testserv",
+ content={"filter": search_filter},
+ access_token=self.token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.federation_client.get_public_rooms.assert_called_once_with(
+ "testserv",
+ limit=100,
+ since_token=None,
+ search_filter=search_filter,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ )
+
+ def test_fallback(self):
+ "Test that searching public rooms over federation falls back if it gets a 404"
+
+ # The `get_public_rooms` should be called again if the first call fails
+ # with a 404, when using search filters.
+ self.federation_client.get_public_rooms.side_effect = (
+ HttpResponseException(404, "Not Found", b""),
+ defer.succeed({}),
+ )
+
+ search_filter = {"generic_search_term": "foobar"}
+
+ channel = self.make_request(
+ "POST",
+ b"/_matrix/client/r0/publicRooms?server=testserv",
+ content={"filter": search_filter},
+ access_token=self.token,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.federation_client.get_public_rooms.assert_has_calls(
+ [
+ call(
+ "testserv",
+ limit=100,
+ since_token=None,
+ search_filter=search_filter,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ),
+ call(
+ "testserv",
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ),
+ ]
+ )
+
+
class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index c9c99cc5d7..6db7062a8e 100644
--- a/tests/rest/client/v2_alpha/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -13,8 +13,7 @@
# limitations under the License.
from synapse.rest import admin
-from synapse.rest.client.v1 import login
-from synapse.rest.client.v2_alpha import sendtodevice, sync
+from synapse.rest.client import login, sendtodevice, sync
from tests.unittest import HomeserverTestCase, override_config
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 288ee12888..6a0d9a82be 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -16,8 +16,13 @@ from unittest.mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
-from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.rest.client import (
+ directory,
+ login,
+ profile,
+ room,
+ room_upgrade_rest_servlet,
+)
from synapse.types import UserID
from tests import unittest
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
index cedb9614a8..283eccd53f 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/test_shared_rooms.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.rest.admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import shared_rooms
+from synapse.rest.client import login, room, shared_rooms
from tests import unittest
from tests.server import FakeChannel
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/test_sync.py
index f6ae9ae181..95be369d4b 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,8 +21,7 @@ from synapse.api.constants import (
ReadReceiptEventFields,
RelationTypes,
)
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import knock, read_marker, receipts, sync
+from synapse.rest.client import knock, login, read_marker, receipts, room, sync
from tests import unittest
from tests.federation.transport.test_knocking import (
@@ -418,6 +417,18 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's hidden read receipt
self.assertEqual(self._get_read_receipt(), None)
+ def test_read_receipt_with_empty_body(self):
+ # Send a message as the first user
+ res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+ # Send a read receipt for this message with an empty body
+ channel = self.make_request(
+ "POST",
+ "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]),
+ access_token=self.tok2,
+ )
+ self.assertEqual(channel.code, 200)
+
def _get_read_receipt(self):
"""Syncs and returns the read receipt."""
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 28dd47a28b..0ae4029640 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -19,7 +19,7 @@ from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.types import Requester, StateMap
from synapse.util.frozenutils import unfreeze
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/test_typing.py
index 44e22ca999..b54b004733 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -17,7 +17,7 @@
from unittest.mock import Mock
-from synapse.rest.client.v1 import room
+from synapse.rest.client import room
from synapse.types import UserID
from tests import unittest
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5f3f15fc57..72f976d8e2 100644
--- a/tests/rest/client/v2_alpha/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -15,8 +15,7 @@ from typing import Optional
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.rest.client import login, room, room_upgrade_rest_servlet
from tests import unittest
from tests.server import FakeChannel
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/utils.py
index fc2d35596e..954ad1a1fd 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/utils.py
@@ -47,10 +47,10 @@ class RestHelper:
def create_room_as(
self,
- room_creator: str = None,
+ room_creator: Optional[str] = None,
is_public: bool = True,
- room_version: str = None,
- tok: str = None,
+ room_version: Optional[str] = None,
+ tok: Optional[str] = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
custom_headers: Optional[
diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py
deleted file mode 100644
index 5e83dba2ed..0000000000
--- a/tests/rest/client/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tests/rest/client/v2_alpha/__init__.py
+++ /dev/null
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2d6b49692e..2f7eebfe69 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -21,7 +21,7 @@ from unittest.mock import Mock
from urllib import parse
import attr
-from parameterized import parameterized_class
+from parameterized import parameterized, parameterized_class
from PIL import Image as Image
from twisted.internet import defer
@@ -30,7 +30,7 @@ from twisted.internet.defer import Deferred
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
from synapse.rest import admin
-from synapse.rest.client.v1 import login
+from synapse.rest.client import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -473,6 +473,43 @@ class MediaRepoTests(unittest.HomeserverTestCase):
},
)
+ @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
+ def test_same_quality(self, method, desired_size):
+ """Test that choosing between thumbnails with the same quality rating succeeds.
+
+ We are not particular about which thumbnail is chosen."""
+ self.assertIsNotNone(
+ self.thumbnail_resource._select_thumbnail(
+ desired_width=desired_size,
+ desired_height=desired_size,
+ desired_method=method,
+ desired_type=self.test_image.content_type,
+ # Provide two identical thumbnails which are guaranteed to have the same
+ # quality rating.
+ thumbnail_infos=[
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ },
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ },
+ ],
+ file_id=f"image{self.test_image.extension}",
+ url_cache=None,
+ server_name=None,
+ )
+ )
+
def test_x_robots_tag_header(self):
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index d3ef7bb4c6..7fa9027227 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -14,13 +14,14 @@
import json
import os
import re
-from unittest.mock import patch
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
+from synapse.config.oembed import OEmbedEndpointConfig
+
from tests import unittest
from tests.server import FakeTransport
@@ -81,6 +82,19 @@ class URLPreviewTests(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
+ # After the hs is created, modify the parsed oEmbed config (to avoid
+ # messing with files).
+ #
+ # Note that HTTP URLs are used to avoid having to deal with TLS in tests.
+ hs.config.oembed.oembed_patterns = [
+ OEmbedEndpointConfig(
+ api_endpoint="http://publish.twitter.com/oembed",
+ url_patterns=[
+ re.compile(r"http://twitter\.com/.+/status/.+"),
+ ],
+ )
+ ]
+
return hs
def prepare(self, reactor, clock, hs):
@@ -544,123 +558,101 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def test_oembed_photo(self):
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
- # Route the HTTP version to an HTTP endpoint so that the tests work.
- with patch.dict(
- "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
- {
- re.compile(
- r"http://twitter\.com/.+/status/.+"
- ): "http://publish.twitter.com/oembed",
- },
- clear=True,
- ):
-
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "photo",
- "url": "http://cdn.twitter.com/matrixdotorg",
- }
- oembed_content = json.dumps(result).encode("utf-8")
-
- end_content = (
- b"<html><head>"
- b"<title>Some Title</title>"
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- channel = self.make_request(
- "GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(oembed_content),)
- + oembed_content
- )
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
- self.pump()
-
- client = self.reactor.tcpClients[1][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
+ end_content = (
+ b"<html><head>"
+ b"<title>Some Title</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
)
+ % (len(oembed_content),)
+ + oembed_content
+ )
- self.pump()
+ self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ )
def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
- # Route the HTTP version to an HTTP endpoint so that the tests work.
- with patch.dict(
- "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
- {
- re.compile(
- r"http://twitter\.com/.+/status/.+"
- ): "http://publish.twitter.com/oembed",
- },
- clear=True,
- ):
-
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "rich",
- "html": "<div>Content Preview</div>",
- }
- end_content = json.dumps(result).encode("utf-8")
-
- channel = self.make_request(
- "GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body,
- {"og:title": None, "og:description": "Content Preview"},
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
)
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"og:title": None, "og:description": "Content Preview"},
+ )
diff --git a/tests/server.py b/tests/server.py
index 6fddd3b305..b861c7b866 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -10,9 +10,10 @@ from zope.interface import implementer
from twisted.internet import address, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
-from twisted.internet.defer import Deferred, fail, succeed
+from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
+ IAddress,
IHostnameResolver,
IProtocol,
IPullProducer,
@@ -511,6 +512,9 @@ class FakeTransport:
will get called back for connectionLost() notifications etc.
"""
+ _peer_address: Optional[IAddress] = attr.ib(default=None)
+ """The value to be returend by getPeer"""
+
disconnecting = False
disconnected = False
connected = True
@@ -519,7 +523,7 @@ class FakeTransport:
autoflush = attr.ib(default=True)
def getPeer(self):
- return None
+ return self._peer_address
def getHost(self):
return None
@@ -572,7 +576,12 @@ class FakeTransport:
self.producerStreaming = streaming
def _produce():
- d = self.producer.resumeProducing()
+ if not self.producer:
+ # we've been unregistered
+ return
+ # some implementations of IProducer (for example, FileSender)
+ # don't return a deferred.
+ d = maybeDeferred(self.producer.resumeProducing)
d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
if not streaming:
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index ac98259b7e..58b399a043 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -15,8 +15,7 @@
import os
import synapse.rest.admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client import login, room, sync
from tests import unittest
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 3245aa91ca..8701b5f7e3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -19,8 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client import login, room, sync
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 932970fd9a..a649e8c618 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -14,7 +14,10 @@
import json
from synapse.logging.context import LoggingContext
+from synapse.rest import admin
+from synapse.rest.client import login, room
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
@@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
self.assertEquals(res, {"event10"})
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+
+
+class EventCacheTestCase(unittest.HomeserverTestCase):
+ """Test that the various layers of event cache works."""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store: EventsWorkerStore = hs.get_datastore()
+
+ self.user = self.register_user("user", "pass")
+ self.token = self.login(self.user, "pass")
+
+ self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+ res = self.helper.send(self.room, tok=self.token)
+ self.event_id = res["event_id"]
+
+ # Reset the event cache so the tests start with it empty
+ self.store._get_event_cache.clear()
+
+ def test_simple(self):
+ """Test that we cache events that we pull from the DB."""
+
+ with LoggingContext("test") as ctx:
+ self.get_success(self.store.get_event(self.event_id))
+
+ # We should have fetched the event from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+ def test_dedupe(self):
+ """Test that if we request the same event multiple times we only pull it
+ out once.
+ """
+
+ with LoggingContext("test") as ctx:
+ d = yieldable_gather_results(
+ self.store.get_event, [self.event_id, self.event_id]
+ )
+ self.get_success(d)
+
+ # We should have fetched the event from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
new file mode 100644
index 0000000000..ffee707153
--- /dev/null
+++ b/tests/storage/databases/main/test_room.py
@@ -0,0 +1,98 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.storage.databases.main.room import _BackgroundUpdates
+
+from tests.unittest import HomeserverTestCase
+
+
+class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.user_id = self.register_user("foo", "pass")
+ self.token = self.login("foo", "pass")
+
+ def _generate_room(self) -> str:
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ return room_id
+
+ def test_background_populate_rooms_creator_column(self):
+ """Test that the background update to populate the rooms creator column
+ works properly.
+ """
+
+ # Insert a room without the creator
+ room_id = self._generate_room()
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"creator": None},
+ desc="test",
+ )
+ )
+
+ # Make sure the test is starting out with a room without a creator
+ room_creator_before = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="creator",
+ allow_none=True,
+ )
+ )
+ self.assertEqual(room_creator_before, None)
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Make sure the background update filled in the room creator
+ room_creator_after = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="creator",
+ allow_none=True,
+ )
+ )
+ self.assertEqual(room_creator_after, self.user_id)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 77c4fe721c..da98733ce8 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.storage import prepare_database
from synapse.types import UserID, create_requester
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index e57fce9694..1c2df54ecc 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
-from synapse.rest.client.v1 import login
+from synapse.rest.client import login
from tests import unittest
from tests.server import make_request
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index d87f124c26..93136f0717 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.storage.databases.main.events import _LinkMap
from synapse.types import create_requester
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index a0e2259478..c3fcf7e7b4 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -15,7 +15,9 @@
import attr
from parameterized import parameterized
+from synapse.api.room_versions import RoomVersions
from synapse.events import _EventInternalMetadata
+from synapse.util import json_encoder
import tests.unittest
import tests.utils
@@ -504,6 +506,61 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertSetEqual(difference, set())
+ def test_prune_inbound_federation_queue(self):
+ "Test that pruning of inbound federation queues work"
+
+ room_id = "some_room_id"
+
+ # Insert a bunch of events that all reference the previous one.
+ self.get_success(
+ self.store.db_pool.simple_insert_many(
+ table="federation_inbound_events_staging",
+ values=[
+ {
+ "origin": "some_origin",
+ "room_id": room_id,
+ "received_ts": 0,
+ "event_id": f"$fake_event_id_{i + 1}",
+ "event_json": json_encoder.encode(
+ {"prev_events": [f"$fake_event_id_{i}"]}
+ ),
+ "internal_metadata": "{}",
+ }
+ for i in range(500)
+ ],
+ desc="test_prune_inbound_federation_queue",
+ )
+ )
+
+ # Calling prune once should return True, i.e. a prune happen. The second
+ # time it shouldn't.
+ pruned = self.get_success(
+ self.store.prune_staged_events_in_room(room_id, RoomVersions.V6)
+ )
+ self.assertTrue(pruned)
+
+ pruned = self.get_success(
+ self.store.prune_staged_events_in_room(room_id, RoomVersions.V6)
+ )
+ self.assertFalse(pruned)
+
+ # Assert that we only have a single event left in the queue, and that it
+ # is the last one.
+ count = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcol="COALESCE(COUNT(*), 0)",
+ desc="test_prune_inbound_federation_queue",
+ )
+ )
+ self.assertEqual(count, 1)
+
+ _, event_id = self.get_success(
+ self.store.get_next_staged_event_id_for_room(room_id)
+ )
+ self.assertEqual(event_id, "$fake_event_id_500")
+
@attr.s
class FakeEvent:
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 617bc8091f..f462a8b1c7 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -17,7 +17,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.federation.federation_base import event_from_pdu_json
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from tests.unittest import HomeserverTestCase
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index e5574063f1..22a77c3ccc 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,7 +13,7 @@
# limitations under the License.
from synapse.api.errors import NotFoundError, SynapseError
-from synapse.rest.client.v1 import room
+from synapse.rest.client import room
from tests.unittest import HomeserverTestCase
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 9fa968f6bb..c72dc40510 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -15,7 +15,7 @@
from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
from tests import unittest
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index 9be51f9ebd..6ff3ebb137 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -32,5 +32,5 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
db_pool = self.hs.get_datastores().databases[0]
# force txn limit to roll over at least once
- for i in range(0, 1001):
+ for _ in range(0, 1001):
self.get_success_or_raise(db_pool.runInteraction("test_select", do_select))
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index e5550aec4d..6ebd01bcbe 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -384,7 +384,7 @@ class EventAuthTestCase(unittest.TestCase):
},
)
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
authorised_join_event,
auth_events,
do_sig_check=False,
@@ -400,7 +400,7 @@ class EventAuthTestCase(unittest.TestCase):
"@inviter:foo.test"
)
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
_join_event(
pleb,
additional_content={
@@ -414,7 +414,7 @@ class EventAuthTestCase(unittest.TestCase):
# A join which is missing an authorised server is rejected.
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
_join_event(pleb),
auth_events,
do_sig_check=False,
@@ -427,7 +427,7 @@ class EventAuthTestCase(unittest.TestCase):
)
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
_join_event(
pleb,
additional_content={
@@ -442,7 +442,7 @@ class EventAuthTestCase(unittest.TestCase):
# *would* be valid, but is sent be a different user.)
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
_member_event(
pleb,
"join",
@@ -459,7 +459,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
with self.assertRaises(AuthError):
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
authorised_join_event,
auth_events,
do_sig_check=False,
@@ -468,7 +468,7 @@ class EventAuthTestCase(unittest.TestCase):
# A user who left can re-join.
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
authorised_join_event,
auth_events,
do_sig_check=False,
@@ -478,7 +478,7 @@ class EventAuthTestCase(unittest.TestCase):
# be authorised since the user is already joined.)
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
_join_event(pleb),
auth_events,
do_sig_check=False,
@@ -490,7 +490,7 @@ class EventAuthTestCase(unittest.TestCase):
pleb, "invite", sender=creator
)
event_auth.check(
- RoomVersions.MSC3083,
+ RoomVersions.V8,
_join_event(pleb),
auth_events,
do_sig_check=False,
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 0ed8326f55..61c9d7c2ef 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,10 +75,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
self.handler = self.homeserver.get_federation_handler()
- self.handler._check_event_auth = (
- lambda origin, event, context, state, auth_events, backfilled: succeed(
- context
- )
+ federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
+ context
)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
@@ -88,9 +87,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Send the join, it should return None (which is not an error)
self.assertEqual(
self.get_success(
- self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
- )
+ federation_event_handler.on_receive_pdu("test.serv", join_event)
),
None,
)
@@ -135,11 +132,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
+ federation_event_handler = self.homeserver.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
- self.handler.on_receive_pdu(
- "test.serv", lying_event, sent_to_us_directly=True
- ),
+ federation_event_handler.on_receive_pdu("test.serv", lying_event),
FederationError,
)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index fa6ef92b3b..66111eb367 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -17,7 +17,7 @@
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
-from synapse.rest.client.v2_alpha import register, sync
+from synapse.rest.client import register, sync
from tests import unittest
from tests.unittest import override_config
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 0df480db9f..67dcf567cd 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactorClock
-from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.rest.client.register import register_servlets
from synapse.util import Clock
from tests import unittest
diff --git a/tests/unittest.py b/tests/unittest.py
index 3eec9c4d5b..f2c90cc47b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase):
reactor=self.reactor,
)
- from tests.rest.client.v1.utils import RestHelper
+ from tests.rest.client.utils import RestHelper
self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
|