diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 4003869ed6..236b608d58 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org"
- self.assertFalse((yield self.service.is_interested(self.event)))
+ self.assertFalse(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_member_is_checked(self):
@@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_id_match(self):
@@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_room_id_no_match(self):
@@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org")
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
- self.assertFalse((yield self.service.is_interested(self.event)))
+ self.assertFalse(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = [
- "#irc_foobar:matrix.org",
- "#athing:matrix.org",
- ]
- self.store.get_users_in_room.return_value = []
- self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertTrue(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
@@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org")
)
- self.store.get_aliases_for_room.return_value = [
- "#xmpp_foobar:matrix.org",
- "#athing:matrix.org",
- ]
- self.store.get_users_in_room.return_value = []
- self.assertFalse((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertFalse(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
@defer.inlineCallbacks
def test_regex_multiple_matches(self):
@@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
)
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org"
- self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
- self.store.get_users_in_room.return_value = []
- self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+ self.store.get_aliases_for_room.return_value = defer.succeed(
+ ["#irc_barfoo:matrix.org"]
+ )
+ self.store.get_users_in_room.return_value = defer.succeed([])
+ self.assertTrue(
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(self.event, self.store)
+ )
+ )
+ )
@defer.inlineCallbacks
def test_interested_in_self(self):
@@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.type = "m.room.member"
self.event.content = {"membership": "invite"}
self.event.state_key = self.service.sender
- self.assertTrue((yield self.service.is_interested(self.event)))
+ self.assertTrue(
+ (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+ )
@defer.inlineCallbacks
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
- self.store.get_users_in_room.return_value = [
- "@alice:here",
- "@irc_fo:here", # AS user
- "@bob:here",
- ]
- self.store.get_aliases_for_room.return_value = []
+ # Note that @irc_fo:here is the AS user.
+ self.store.get_users_in_room.return_value = defer.succeed(
+ ["@alice:here", "@irc_fo:here", "@bob:here"]
+ )
+ self.store.get_aliases_for_room.return_value = defer.succeed([])
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(
- (yield self.service.is_interested(event=self.event, store=self.store))
+ (
+ yield defer.ensureDeferred(
+ self.service.is_interested(event=self.event, store=self.store)
+ )
+ )
)
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 52f89d3f83..68a4caabbf 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
from synapse.logging.context import make_deferred_yieldable
from tests import unittest
+from tests.test_utils import make_awaitable
from ..utils import MockClock
@@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.UP)
)
- txn.send = Mock(return_value=defer.succeed(True))
+ txn.send = Mock(return_value=make_awaitable(True))
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
@@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
@@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
return_value=defer.succeed(ApplicationServiceState.UP)
)
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
- txn.send = Mock(return_value=defer.succeed(False)) # fails to send
+ txn.send = Mock(return_value=make_awaitable(False)) # fails to send
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call
- self.txnctrl.send(service, events)
+ self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events
@@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=True)
+ txn.send = Mock(return_value=make_awaitable(True))
+ txn.complete.return_value = make_awaitable(None)
# wait for exp backoff
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
@@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
- txn.send = Mock(return_value=False)
+ txn.send = Mock(return_value=make_awaitable(False))
+ txn.complete.return_value = make_awaitable(None)
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
@@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
- txn.send = Mock(return_value=True) # successfully send the txn
+ txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f9ce609923..e0ad8e8a77 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
persp_deferred = defer.Deferred()
- @defer.inlineCallbacks
- def get_perspectives(**kwargs):
+ async def get_perspectives(**kwargs):
self.assertEquals(current_context().request, "11")
with PreserveLoggingContext():
- yield persp_deferred
+ await persp_deferred
return persp_resp
self.http_client.post_json.side_effect = get_perspectives
@@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- def get_json(destination, path, **kwargs):
+ async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response
@@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect a perspectives-server key query
"""
- def post_json(destination, path, data, **kwargs):
+ async def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
@@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# remove the perspectives server's signature
response = build_response()
del response["signatures"][self.mock_perspective_server.server_name]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
# remove the origin server's signature
response = build_response()
del response["signatures"][SERVER_NAME]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 640f5f3bce..3a80626224 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase):
serialize/deserialize.
"""
- event, context = create_event(
- self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ event, context = self.get_success(
+ create_event(
+ self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ )
)
self._check_serialize_deserialize(event, context)
@@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
- event, context = create_event(
- self.hs,
- room_id=self.room_id,
- type="m.test",
- sender=self.user_id,
- state_key="",
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ state_key="",
+ )
)
self._check_serialize_deserialize(event, context)
@@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase):
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
- event, context = create_event(
- self.hs,
- room_id=self.room_id,
- type="m.room.member",
- sender=self.user_id,
- state_key=self.user_id,
- content={"membership": "leave"},
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.room.member",
+ sender=self.user_id,
+ state_key=self.user_id,
+ content={"membership": "leave"},
+ )
)
self._check_serialize_deserialize(event, context)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0c9987be54..b8ca118716 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -78,9 +79,40 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
- return_value=defer.succeed(("", 1))
+ return_value=make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request failed with a SynapseError saying the resource limit was
+ # exceeded.
+ f = self.get_failure(d, SynapseError)
+ self.assertEqual(f.value.code, 400, f.value)
+ self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_join_too_large_admin(self):
+ # Check whether an admin can join if option "admins_can_join" is undefined,
+ # this option defaults to false, so the join should fail.
+
+ u1 = self.register_user("u1", "pass", admin=True)
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@@ -116,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
+ fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock(
- return_value=defer.succeed(("", 1))
+ return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
@@ -141,3 +173,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+
+class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
+ # Test the behavior of joining rooms which exceed the complexity if option
+ # limit_remote_rooms.admins_can_join is True.
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["limit_remote_rooms"] = {
+ "enabled": True,
+ "complexity": 0.05,
+ "admins_can_join": True,
+ }
+ return config
+
+ def test_join_too_large_no_admin(self):
+ # A user which is not an admin should not be able to join a remote room
+ # which is too complex.
+
+ u1 = self.register_user("u1", "pass")
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request failed with a SynapseError saying the resource limit was
+ # exceeded.
+ f = self.get_failure(d, SynapseError)
+ self.assertEqual(f.value.code, 400, f.value)
+ self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+ def test_join_too_large_admin(self):
+ # An admin should be able to join rooms where a complexity check fails.
+
+ u1 = self.register_user("u1", "pass", admin=True)
+
+ handler = self.hs.get_room_member_handler()
+ fed_transport = self.hs.get_federation_transport_client()
+
+ # Mock out some things, because we don't want to test the whole join
+ fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+ handler.federation_handler.do_invite_join = Mock(
+ return_value=make_awaitable(("", 1))
+ )
+
+ d = handler._remote_join(
+ None,
+ ["other.example.com"],
+ "roomid",
+ UserID.from_string(u1),
+ {"membership": "join"},
+ )
+
+ self.pump()
+
+ # The request success since the user is an admin
+ self.get_success(d)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 1a9bd5f37d..5f512ff8bf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -26,31 +26,34 @@ from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
+from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+ # Ensure a new Awaitable is created for each call.
+ mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+ ["test", "host2"]
+ )
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ state_handler=mock_state_handler,
federation_transport_client=Mock(spec=["send_transaction"]),
)
@override_config({"send_federation": True})
def test_send_receipts(self):
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
- mock_send_transaction.return_value = defer.succeed({})
+ mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@@ -81,19 +84,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
- mock_send_transaction.return_value = defer.succeed({})
+ mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
)
- self.successResultOf(sender.send_read_receipt(receipt))
+ self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
mock_send_transaction.assert_not_called()
@@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out get_current_hosts_in_room
- mock_state_handler = hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ebabe9a7d6..628f7d8db0 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler
+from tests.test_utils import make_awaitable
from tests.utils import MockClock
from .. import unittest
@@ -117,7 +118,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_interested_in_alias=False),
]
- self.mock_as_api.query_alias.return_value = defer.succeed(True)
+ self.mock_as_api.query_alias.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services
self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
Mock(room_id=room_id, servers=servers)
@@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def _mkservice(self, is_interested):
service = Mock()
- service.is_interested.return_value = defer.succeed(is_interested)
+ service.is_interested.return_value = make_awaitable(is_interested)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
from mock import Mock
-from twisted.internet import defer
-
import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
from synapse.types import RoomAlias, create_requester
from tests import unittest
+from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 4f1347cd25..d70e1fc608 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_other_name(self):
- self.mock_federation.make_query.return_value = defer.succeed(
+ self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 954e059e76..69945a8f98 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory():
return test_server_connection_factory
+# Once Async Mocks or lambdas are supported this can go away.
+def generate_resolve_service(result):
+ async def resolve_service(_):
+ return result
+
+ return resolve_service
+
+
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
@@ -373,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@@ -456,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -510,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -572,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -661,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -717,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
@@ -764,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when there is a single SRV record
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -819,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self._handle_well_known_connection(
client_factory,
@@ -861,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -922,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
+ )
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@@ -1087,11 +1095,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""
-
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"target.com", port=8443),
- Server(host=b"target.com", port=8444),
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ )
self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
-
- self.assertNoResult(resolve_d)
-
- # should have reset to the sentinel context
- self.assertIs(current_context(), SENTINEL_CONTEXT)
-
- result = yield resolve_d
+ result = yield defer.ensureDeferred(resolve_d)
# should have restored our context
self.assertIs(current_context(), ctx)
@@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertFalse(dns_client_mock.lookupService.called)
@@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
- yield resolver.resolve_service(service_name)
+ yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks
def test_name_error(self):
@@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
# returning a single "." should make the lookup fail with a ConenctError
lookup_deferred.callback(
@@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in successResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
lookup_deferred.callback(
(
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..ac598249e4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase):
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
- fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+ fetch_d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar")
+ )
# Nothing happened yet
self.assertNoResult(fetch_d)
@@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase):
"""
If the DNS lookup returns an error, it will bubble up.
"""
- d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ )
self.pump()
f = self.failureResultOf(d)
@@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a blacklisted IPv4 address
# ------------------------------------------------------
# Make the request
- d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
# Nothing happened yet
self.assertNoResult(d)
@@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase):
# Try making a POST request to a blacklisted IPv6 address
# -------------------------------------------------------
# Make the request
- d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ )
# Nothing has happened yet
self.assertNoResult(d)
@@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a non-blacklisted IPv4 address
# ----------------------------------------------------------
# Make the request
- d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
# Nothing has happened yet
self.assertNoResult(d)
@@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase):
request = MatrixFederationRequest(
method="GET", destination="testserv:8008", path="foo/bar"
)
- d = self.cl._send_request(request, timeout=10000)
+ d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
self.pump()
@@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
- d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase):
requiring a trailing slash. We need to retry the request with a
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase):
See test_client_requires_trailing_slashes() for context.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase):
self.failureResultOf(d)
def test_client_sends_body(self):
- self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+ defer.ensureDeferred(
+ self.cl.post_json(
+ "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+ )
+ )
self.pump()
@@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase):
def test_closes_connection(self):
"""Check that the client closes unused HTTP connections"""
- d = self.cl.get_json("testserv:8008", "foo/bar")
+ d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
self.pump()
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..0b5204654c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
- self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ self.get_success(
+ self.master_store.add_push_actions_to_staging(
+ event.event_id, {user_id: actions for user_id, actions in push_actions}
+ )
)
return event, context
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 097e1653b4..c9998e88e6 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
OTHER_USER = "@other_user:localhost"
# have the user join
- inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -157,14 +159,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
- pl_event = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ pl_event = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
# one more bit of state that doesn't get rolled back
@@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# have the users join
for u in user_ids:
- inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -306,14 +312,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_events = []
for u in user_ids:
pls["users"][u] = 0
- e = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ e = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
prev_events = [e.event_id]
pl_events.append(e)
@@ -434,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase):
body = "event %i" % (self.event_count,)
self.event_count += 1
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_event",
- content={"body": body},
- **kwargs
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_event",
+ content={"body": body},
+ **kwargs
+ )
)
def _inject_state_event(
@@ -459,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase):
if body is None:
body = "state event %s" % (state_key,)
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_state_event",
- state_key=state_key,
- content={"body": body},
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_state_event",
+ state_key=state_key,
+ content={"body": body},
+ )
)
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8d4dbf232e..83f9aa291c 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,8 +16,6 @@ import logging
from mock import Mock
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
@@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
logger = logging.getLogger(__name__)
@@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event.
"""
mock_client = Mock(spec=["put_json"])
- mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
@@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs.
"""
mock_client1 = Mock(spec=["put_json"])
- mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
- mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+ mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index b1a4decced..0f1144fe1e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.fetches = []
- def get_file(destination, path, output_stream, args=None, max_size=None):
+ async def get_file(destination, path, output_stream, args=None, max_size=None):
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
@@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
d = Deferred()
d.addCallback(write_to)
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ return await make_deferred_yieldable(d)
client = Mock()
client.get_file = get_file
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 946f06d151..cec1cf928f 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1,1447 +1,1500 @@
-# -*- coding: utf-8 -*-
-# Copyright 2020 Dirk Klimpel
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import urllib.parse
-from typing import List, Optional
-
-from mock import Mock
-
-import synapse.rest.admin
-from synapse.api.errors import Codes
-from synapse.rest.client.v1 import directory, events, login, room
-
-from tests import unittest
-
-"""Tests admin REST events for /rooms paths."""
-
-
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "admin/shutdown_room/" + room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
-class DeleteRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_tok = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- self.room_id = self.helper.create_room_as(
- self.other_user, tok=self.other_user_tok
- )
- self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
-
- def test_requester_is_no_admin(self):
- """
- If the user is not a server admin, an error 403 is returned.
- """
-
- request, channel = self.make_request(
- "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_room_does_not_exist(self):
- """
- Check that unknown rooms/server return error 404.
- """
- url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
-
- request, channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_room_is_not_valid(self):
- """
- Check that invalid room names, return an error 400.
- """
- url = "/_synapse/admin/v1/rooms/invalidroom/delete"
-
- request, channel = self.make_request(
- "POST", url, json.dumps({}), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "invalidroom is not a legal room ID", channel.json_body["error"],
- )
-
- def test_new_room_user_does_not_exist(self):
- """
- Tests that the user ID must be from local server but it does not have to exist.
- """
- body = json.dumps({"new_room_user_id": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertIn("new_room_id", channel.json_body)
- self.assertIn("kicked_users", channel.json_body)
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- def test_new_room_user_is_not_local(self):
- """
- Check that only local users can create new room to move members.
- """
- body = json.dumps({"new_room_user_id": "@not:exist.bla"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "User must be our own: @not:exist.bla", channel.json_body["error"],
- )
-
- def test_block_is_not_bool(self):
- """
- If parameter `block` is not boolean, return an error
- """
- body = json.dumps({"block": "NotBool"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
-
- def test_purge_room_and_block(self):
- """Test to purge a room and block it.
- Members will not be moved to a new room and will not receive a message.
- """
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Test that room is not blocked
- self._is_blocked(self.room_id, expect=False)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- body = json.dumps({"block": True})
-
- request, channel = self.make_request(
- "POST",
- self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(None, channel.json_body["new_room_id"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- self._is_purged(self.room_id)
- self._is_blocked(self.room_id, expect=True)
- self._has_no_members(self.room_id)
-
- def test_purge_room_and_not_block(self):
- """Test to purge a room and do not block it.
- Members will not be moved to a new room and will not receive a message.
- """
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Test that room is not blocked
- self._is_blocked(self.room_id, expect=False)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- body = json.dumps({"block": False})
-
- request, channel = self.make_request(
- "POST",
- self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(None, channel.json_body["new_room_id"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- self._is_purged(self.room_id)
- self._is_blocked(self.room_id, expect=False)
- self._has_no_members(self.room_id)
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- Members will be moved to a new room and will receive a message.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
- )
-
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("new_room_id", channel.json_body)
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- # Test that member has moved to new room
- self._is_member(
- room_id=channel.json_body["new_room_id"], user_id=self.other_user
- )
-
- self._is_purged(self.room_id)
- self._has_no_members(self.room_id)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- Members will be moved to a new room and will receive a message.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that room is not purged
- with self.assertRaises(AssertionError):
- self._is_purged(self.room_id)
-
- # Assert one user in room
- self._is_member(room_id=self.room_id, user_id=self.other_user)
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
- self.assertIn("new_room_id", channel.json_body)
- self.assertIn("failed_to_kick_users", channel.json_body)
- self.assertIn("local_aliases", channel.json_body)
-
- # Test that member has moved to new room
- self._is_member(
- room_id=channel.json_body["new_room_id"], user_id=self.other_user
- )
-
- self._is_purged(self.room_id)
- self._has_no_members(self.room_id)
-
- # Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=403)
-
- def _is_blocked(self, room_id, expect=True):
- """Assert that the room is blocked or not
- """
- d = self.store.is_room_blocked(room_id)
- if expect:
- self.assertTrue(self.get_success(d))
- else:
- self.assertIsNone(self.get_success(d))
-
- def _has_no_members(self, room_id):
- """Assert there is now no longer anyone in the room
- """
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def _is_member(self, room_id, user_id):
- """Test that user is member of the room
- """
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertIn(user_id, users_in_room)
-
- def _is_purged(self, room_id):
- """Test that the following tables have been purged of all rows related to the room.
- """
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room.
- """
-
- url = "rooms/%s/initialSync" % (room_id,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.render(request)
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- request, channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "room_account_data",
- "room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
- "state_groups_state",
- ):
- count = self.get_success(
- self.store.db.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
-
-
-class RoomTestCase(unittest.HomeserverTestCase):
- """Test /room admin API.
- """
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- directory.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- # Create user
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_list_rooms(self):
- """Test that we can list rooms"""
- # Create 3 test rooms
- total_rooms = 3
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
-
- # Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
-
- # Check that response json body contains a "rooms" key
- self.assertTrue(
- "rooms" in channel.json_body,
- msg="Response body does not " "contain a 'rooms' key",
- )
-
- # Check that 3 rooms were returned
- self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
-
- # Check their room_ids match
- returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
- self.assertEqual(room_ids, returned_room_ids)
-
- # Check that all fields are available
- for r in channel.json_body["rooms"]:
- self.assertIn("name", r)
- self.assertIn("canonical_alias", r)
- self.assertIn("joined_members", r)
- self.assertIn("joined_local_members", r)
- self.assertIn("version", r)
- self.assertIn("creator", r)
- self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
- self.assertIn("join_rules", r)
- self.assertIn("guest_access", r)
- self.assertIn("history_visibility", r)
- self.assertIn("state_events", r)
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # Should be 0 as we aren't paginating
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that the prev_batch parameter is not present
- self.assertNotIn("prev_batch", channel.json_body)
-
- # We shouldn't receive a next token here as there's no further rooms to show
- self.assertNotIn("next_batch", channel.json_body)
-
- def test_list_rooms_pagination(self):
- """Test that we can get a full list of rooms through pagination"""
- # Create 5 test rooms
- total_rooms = 5
- room_ids = []
- for x in range(total_rooms):
- room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
- )
- room_ids.append(room_id)
-
- # Set the name of the rooms so we get a consistent returned ordering
- for idx, room_id in enumerate(room_ids):
- self.helper.send_state(
- room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- returned_room_ids = []
- start = 0
- limit = 2
-
- run_count = 0
- should_repeat = True
- while should_repeat:
- run_count += 1
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
- start,
- limit,
- "name",
- )
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- self.assertTrue("rooms" in channel.json_body)
- for r in channel.json_body["rooms"]:
- returned_room_ids.append(r["room_id"])
-
- # Check that the correct number of total rooms was returned
- self.assertEqual(channel.json_body["total_rooms"], total_rooms)
-
- # Check that the offset is correct
- # We're only getting 2 rooms each page, so should be 2 * last run_count
- self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
-
- if run_count > 1:
- # Check the value of prev_batch is correct
- self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
-
- if "next_batch" not in channel.json_body:
- # We have reached the end of the list
- should_repeat = False
- else:
- # Make another query with an updated start value
- start = channel.json_body["next_batch"]
-
- # We should've queried the endpoint 3 times
- self.assertEqual(
- run_count,
- 3,
- msg="Should've queried 3 times for 5 rooms with limit 2 per query",
- )
-
- # Check that we received all of the room ids
- self.assertEqual(room_ids, returned_room_ids)
-
- url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- def test_correct_room_attributes(self):
- """Test the correct attributes for a room are returned"""
- # Create a test room
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- test_alias = "#test:test"
- test_room_name = "something"
-
- # Have another user join the room
- user_2 = self.register_user("user4", "pass")
- user_tok_2 = self.login("user4", "pass")
- self.helper.join(room_id, user_2, tok=user_tok_2)
-
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=self.admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=self.admin_user_tok,
- )
-
- # Set a name for the room
- self.helper.send_state(
- room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
- )
-
- # Request the list of rooms
- url = "/_synapse/admin/v1/rooms"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that only one room was returned
- self.assertEqual(len(rooms), 1)
-
- # And that the value of the total_rooms key was correct
- self.assertEqual(channel.json_body["total_rooms"], 1)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that all provided attributes are set
- r = rooms[0]
- self.assertEqual(room_id, r["room_id"])
- self.assertEqual(test_room_name, r["name"])
- self.assertEqual(test_alias, r["canonical_alias"])
-
- def test_room_list_sort_order(self):
- """Test room list sort ordering. alphabetical name versus number of members,
- reversing the order, etc.
- """
-
- def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
- # Create a new alias to this room
- url = "/_matrix/client/r0/directory/room/%s" % (
- urllib.parse.quote(test_alias),
- )
- request, channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=admin_user_tok,
- )
- self.render(request)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- # Set this new alias as the canonical alias for this room
- self.helper.send_state(
- room_id,
- "m.room.aliases",
- {"aliases": [test_alias]},
- tok=admin_user_tok,
- state_key="test",
- )
- self.helper.send_state(
- room_id,
- "m.room.canonical_alias",
- {"alias": test_alias},
- tok=admin_user_tok,
- )
-
- def _order_test(
- order_type: str, expected_room_list: List[str], reverse: bool = False,
- ):
- """Request the list of rooms in a certain order. Assert that order is what
- we expect
-
- Args:
- order_type: The type of ordering to give the server
- expected_room_list: The list of room_ids in the order we expect to get
- back from the server
- """
- # Request the list of rooms in the given order
- url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
- if reverse:
- url += "&dir=b"
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check for the correct total_rooms value
- self.assertEqual(channel.json_body["total_rooms"], 3)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- # Check that rooms were returned in alphabetical order
- returned_order = [r["room_id"] for r in rooms]
- self.assertListEqual(expected_room_list, returned_order) # order is checked
-
- # Create 3 test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
- )
-
- # Set room canonical room aliases
- _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
- _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
-
- # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
- user_1 = self.register_user("bob1", "pass")
- user_1_tok = self.login("bob1", "pass")
- self.helper.join(room_id_2, user_1, tok=user_1_tok)
-
- user_2 = self.register_user("bob2", "pass")
- user_2_tok = self.login("bob2", "pass")
- self.helper.join(room_id_3, user_2, tok=user_2_tok)
-
- user_3 = self.register_user("bob3", "pass")
- user_3_tok = self.login("bob3", "pass")
- self.helper.join(room_id_3, user_3, tok=user_3_tok)
-
- # Test different sort orders, with forward and reverse directions
- _order_test("name", [room_id_1, room_id_2, room_id_3])
- _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
- _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
- _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
- _order_test(
- "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("version", [room_id_1, room_id_2, room_id_3])
- _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("creator", [room_id_1, room_id_2, room_id_3])
- _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("encryption", [room_id_1, room_id_2, room_id_3])
- _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("federatable", [room_id_1, room_id_2, room_id_3])
- _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("public", [room_id_1, room_id_2, room_id_3])
- # Different sort order of SQlite and PostreSQL
- # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
-
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
- _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
- _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
- _order_test(
- "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
- )
-
- _order_test("state_events", [room_id_3, room_id_2, room_id_1])
- _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
-
- def test_search_term(self):
- """Test that searching for a room works correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- def _search_test(
- expected_room_id: Optional[str],
- search_term: str,
- expected_http_code: int = 200,
- ):
- """Search for a room and check that the returned room's id is a match
-
- Args:
- expected_room_id: The room_id expected to be returned by the API. Set
- to None to expect zero results for the search
- search_term: The term to search for room names with
- expected_http_code: The expected http code for the request
- """
- url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
-
- if expected_http_code != 200:
- return
-
- # Check that rooms were returned
- self.assertTrue("rooms" in channel.json_body)
- rooms = channel.json_body["rooms"]
-
- # Check that the expected number of rooms were returned
- expected_room_count = 1 if expected_room_id else 0
- self.assertEqual(len(rooms), expected_room_count)
- self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
-
- # Check that the offset is correct
- # We're not paginating, so should be 0
- self.assertEqual(channel.json_body["offset"], 0)
-
- # Check that there is no `prev_batch`
- self.assertNotIn("prev_batch", channel.json_body)
-
- # Check that there is no `next_batch`
- self.assertNotIn("next_batch", channel.json_body)
-
- if expected_room_id:
- # Check that the first returned room id is correct
- r = rooms[0]
- self.assertEqual(expected_room_id, r["room_id"])
-
- # Perform search tests
- _search_test(room_id_1, "something")
- _search_test(room_id_1, "thing")
-
- _search_test(room_id_2, "else")
- _search_test(room_id_2, "se")
-
- _search_test(None, "foo")
- _search_test(None, "bar")
- _search_test(None, "", expected_http_code=400)
-
- def test_single_room(self):
- """Test that a single room can be requested correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- room_name_1 = "something"
- room_name_2 = "else"
-
- # Set the name for each room
- self.helper.send_state(
- room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
- )
- self.helper.send_state(
- room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
- )
-
- url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertIn("room_id", channel.json_body)
- self.assertIn("name", channel.json_body)
- self.assertIn("canonical_alias", channel.json_body)
- self.assertIn("joined_members", channel.json_body)
- self.assertIn("joined_local_members", channel.json_body)
- self.assertIn("version", channel.json_body)
- self.assertIn("creator", channel.json_body)
- self.assertIn("encryption", channel.json_body)
- self.assertIn("federatable", channel.json_body)
- self.assertIn("public", channel.json_body)
- self.assertIn("join_rules", channel.json_body)
- self.assertIn("guest_access", channel.json_body)
- self.assertIn("history_visibility", channel.json_body)
- self.assertIn("state_events", channel.json_body)
-
- self.assertEqual(room_id_1, channel.json_body["room_id"])
-
- def test_room_members(self):
- """Test that room members can be requested correctly"""
- # Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # Have another user join the room
- user_1 = self.register_user("foo", "pass")
- user_tok_1 = self.login("foo", "pass")
- self.helper.join(room_id_1, user_1, tok=user_tok_1)
-
- # Have another user join the room
- user_2 = self.register_user("bar", "pass")
- user_tok_2 = self.login("bar", "pass")
- self.helper.join(room_id_1, user_2, tok=user_tok_2)
- self.helper.join(room_id_2, user_2, tok=user_tok_2)
-
- # Have another user join the room
- user_3 = self.register_user("foobar", "pass")
- user_tok_3 = self.login("foobar", "pass")
- self.helper.join(room_id_2, user_3, tok=user_tok_3)
-
- url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertCountEqual(
- ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
- )
- self.assertEqual(channel.json_body["total"], 3)
-
- url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
- request, channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, channel.code, msg=channel.json_body)
-
- self.assertCountEqual(
- ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
- )
- self.assertEqual(channel.json_body["total"], 3)
-
-
-class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
- servlets = [
- synapse.rest.admin.register_servlets,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor, clock, homeserver):
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.creator = self.register_user("creator", "test")
- self.creator_tok = self.login("creator", "test")
-
- self.second_user_id = self.register_user("second", "test")
- self.second_tok = self.login("second", "test")
-
- self.public_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=True
- )
- self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
-
- def test_requester_is_no_admin(self):
- """
- If the user is not a server admin, an error 403 is returned.
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.second_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_invalid_parameter(self):
- """
- If a parameter is missing, return an error
- """
- body = json.dumps({"unknown_parameter": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
- def test_local_user_does_not_exist(self):
- """
- Tests that a lookup for a user that does not exist returns a 404
- """
- body = json.dumps({"user_id": "@unknown:test"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_remote_user(self):
- """
- Check that only local user can join rooms.
- """
- body = json.dumps({"user_id": "@not:exist.bla"})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "This endpoint can only be used with local users",
- channel.json_body["error"],
- )
-
- def test_room_does_not_exist(self):
- """
- Check that unknown rooms/server return error 404.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/!unknown:test"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("No known servers", channel.json_body["error"])
-
- def test_room_is_not_valid(self):
- """
- Check that invalid room names, return an error 400.
- """
- body = json.dumps({"user_id": self.second_user_id})
- url = "/_synapse/admin/v1/join/invalidroom"
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(
- "invalidroom was not legal room ID or room alias",
- channel.json_body["error"],
- )
-
- def test_join_public_room(self):
- """
- Test joining a local user to a public room with "JoinRules.PUBLIC"
- """
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- self.url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_not_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE"
- when server admin is not member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_join_private_room_if_member(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is member of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.creator, tok=self.creator_tok, is_public=False
- )
- self.helper.invite(
- room=private_room_id,
- src=self.creator,
- targ=self.admin_user,
- tok=self.creator_tok,
- )
- self.helper.join(
- room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
- )
-
- # Validate if server admin is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- # Join user to room.
-
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
-
- def test_join_private_room_if_owner(self):
- """
- Test joining a local user to a private room with "JoinRules.INVITE",
- when server admin is owner of this room.
- """
- private_room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok, is_public=False
- )
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
- body = json.dumps({"user_id": self.second_user_id})
-
- request, channel = self.make_request(
- "POST",
- url,
- content=body.encode(encoding="utf_8"),
- access_token=self.admin_user_tok,
- )
- self.render(request)
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["room_id"])
-
- # Validate if user is a member of the room
-
- request, channel = self.make_request(
- "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
- )
- self.render(request)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ room_id, body="foo", tok=self.other_user_token, expect_code=403
+ )
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert there is now no longer anyone in the room
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ """
+
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_token,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the admin can still send shutdown
+ url = "admin/shutdown_room/" + room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(room_id, expect_code=403)
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class DeleteRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ request, channel = self.make_request(
+ "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+ url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names, return an error 400.
+ """
+ url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+ request, channel = self.make_request(
+ "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom is not a legal room ID", channel.json_body["error"],
+ )
+
+ def test_new_room_user_does_not_exist(self):
+ """
+ Tests that the user ID must be from local server but it does not have to exist.
+ """
+ body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("kicked_users", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ def test_new_room_user_is_not_local(self):
+ """
+ Check that only local users can create new room to move members.
+ """
+ body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "User must be our own: @not:exist.bla", channel.json_body["error"],
+ )
+
+ def test_block_is_not_bool(self):
+ """
+ If parameter `block` is not boolean, return an error
+ """
+ body = json.dumps({"block": "NotBool"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_is_not_bool(self):
+ """
+ If parameter `purge` is not boolean, return an error
+ """
+ body = json.dumps({"purge": "NotBool"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_room_and_block(self):
+ """Test to purge a room and block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": True, "purge": True})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_purge_room_and_not_block(self):
+ """Test to purge a room and do not block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": False, "purge": True})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_block_room_and_not_purge(self):
+ """Test to block a room without purging it.
+ Members will not be moved to a new room and will not receive a message.
+ The room will not be purged.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ body = json.dumps({"block": False, "purge": False})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url.encode("ascii"),
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(None, channel.json_body["new_room_id"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ json.dumps({"history_visibility": "world_readable"}),
+ access_token=self.other_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ json.dumps({"new_room_user_id": self.admin_user}),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+ self.assertIn("new_room_id", channel.json_body)
+ self.assertIn("failed_to_kick_users", channel.json_body)
+ self.assertIn("local_aliases", channel.json_body)
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["new_room_id"], user_id=self.other_user
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(self.room_id, expect_code=403)
+
+ def _is_blocked(self, room_id, expect=True):
+ """Assert that the room is blocked or not
+ """
+ d = self.store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _has_no_members(self, room_id):
+ """Assert there is now no longer anyone in the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def _is_member(self, room_id, user_id):
+ """Test that user is member of the room
+ """
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertIn(user_id, users_in_room)
+
+ def _is_purged(self, room_id):
+ """Test that the following tables have been purged of all rows related to the room.
+ """
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+ def _assert_peek(self, room_id, expect_code):
+ """Assert that the admin user can (or cannot) peek into the room.
+ """
+
+ url = "rooms/%s/initialSync" % (room_id,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ url = "events?timeout=0&room_id=" + room_id
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.render(request)
+ self.assertEqual(
+ expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "room_account_data",
+ "room_tags",
+ # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups_state",
+ ):
+ count = self.get_success(
+ self.store.db.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+ """Test /room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_list_rooms(self):
+ """Test that we can list rooms"""
+ # Create 3 test rooms
+ total_rooms = 3
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ # Check request completed successfully
+ self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+ # Check that response json body contains a "rooms" key
+ self.assertTrue(
+ "rooms" in channel.json_body,
+ msg="Response body does not " "contain a 'rooms' key",
+ )
+
+ # Check that 3 rooms were returned
+ self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+ # Check their room_ids match
+ returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+ self.assertEqual(room_ids, returned_room_ids)
+
+ # Check that all fields are available
+ for r in channel.json_body["rooms"]:
+ self.assertIn("name", r)
+ self.assertIn("canonical_alias", r)
+ self.assertIn("joined_members", r)
+ self.assertIn("joined_local_members", r)
+ self.assertIn("version", r)
+ self.assertIn("creator", r)
+ self.assertIn("encryption", r)
+ self.assertIn("federatable", r)
+ self.assertIn("public", r)
+ self.assertIn("join_rules", r)
+ self.assertIn("guest_access", r)
+ self.assertIn("history_visibility", r)
+ self.assertIn("state_events", r)
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # Should be 0 as we aren't paginating
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that the prev_batch parameter is not present
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # We shouldn't receive a next token here as there's no further rooms to show
+ self.assertNotIn("next_batch", channel.json_body)
+
+ def test_list_rooms_pagination(self):
+ """Test that we can get a full list of rooms through pagination"""
+ # Create 5 test rooms
+ total_rooms = 5
+ room_ids = []
+ for x in range(total_rooms):
+ room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+ room_ids.append(room_id)
+
+ # Set the name of the rooms so we get a consistent returned ordering
+ for idx, room_id in enumerate(room_ids):
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ returned_room_ids = []
+ start = 0
+ limit = 2
+
+ run_count = 0
+ should_repeat = True
+ while should_repeat:
+ run_count += 1
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+ start,
+ limit,
+ "name",
+ )
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ self.assertTrue("rooms" in channel.json_body)
+ for r in channel.json_body["rooms"]:
+ returned_room_ids.append(r["room_id"])
+
+ # Check that the correct number of total rooms was returned
+ self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+ # Check that the offset is correct
+ # We're only getting 2 rooms each page, so should be 2 * last run_count
+ self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+ if run_count > 1:
+ # Check the value of prev_batch is correct
+ self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+ if "next_batch" not in channel.json_body:
+ # We have reached the end of the list
+ should_repeat = False
+ else:
+ # Make another query with an updated start value
+ start = channel.json_body["next_batch"]
+
+ # We should've queried the endpoint 3 times
+ self.assertEqual(
+ run_count,
+ 3,
+ msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+ )
+
+ # Check that we received all of the room ids
+ self.assertEqual(room_ids, returned_room_ids)
+
+ url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_correct_room_attributes(self):
+ """Test the correct attributes for a room are returned"""
+ # Create a test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ test_alias = "#test:test"
+ test_room_name = "something"
+
+ # Have another user join the room
+ user_2 = self.register_user("user4", "pass")
+ user_tok_2 = self.login("user4", "pass")
+ self.helper.join(room_id, user_2, tok=user_tok_2)
+
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=self.admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=self.admin_user_tok,
+ )
+
+ # Set a name for the room
+ self.helper.send_state(
+ room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+ )
+
+ # Request the list of rooms
+ url = "/_synapse/admin/v1/rooms"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that only one room was returned
+ self.assertEqual(len(rooms), 1)
+
+ # And that the value of the total_rooms key was correct
+ self.assertEqual(channel.json_body["total_rooms"], 1)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that all provided attributes are set
+ r = rooms[0]
+ self.assertEqual(room_id, r["room_id"])
+ self.assertEqual(test_room_name, r["name"])
+ self.assertEqual(test_alias, r["canonical_alias"])
+
+ def test_room_list_sort_order(self):
+ """Test room list sort ordering. alphabetical name versus number of members,
+ reversing the order, etc.
+ """
+
+ def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+ # Create a new alias to this room
+ url = "/_matrix/client/r0/directory/room/%s" % (
+ urllib.parse.quote(test_alias),
+ )
+ request, channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(
+ 200, int(channel.result["code"]), msg=channel.result["body"]
+ )
+
+ # Set this new alias as the canonical alias for this room
+ self.helper.send_state(
+ room_id,
+ "m.room.aliases",
+ {"aliases": [test_alias]},
+ tok=admin_user_tok,
+ state_key="test",
+ )
+ self.helper.send_state(
+ room_id,
+ "m.room.canonical_alias",
+ {"alias": test_alias},
+ tok=admin_user_tok,
+ )
+
+ def _order_test(
+ order_type: str, expected_room_list: List[str], reverse: bool = False,
+ ):
+ """Request the list of rooms in a certain order. Assert that order is what
+ we expect
+
+ Args:
+ order_type: The type of ordering to give the server
+ expected_room_list: The list of room_ids in the order we expect to get
+ back from the server
+ """
+ # Request the list of rooms in the given order
+ url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+ if reverse:
+ url += "&dir=b"
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check for the correct total_rooms value
+ self.assertEqual(channel.json_body["total_rooms"], 3)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ # Check that rooms were returned in alphabetical order
+ returned_order = [r["room_id"] for r in rooms]
+ self.assertListEqual(expected_room_list, returned_order) # order is checked
+
+ # Create 3 test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+ )
+
+ # Set room canonical room aliases
+ _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+ _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+ # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+ user_1 = self.register_user("bob1", "pass")
+ user_1_tok = self.login("bob1", "pass")
+ self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+ user_2 = self.register_user("bob2", "pass")
+ user_2_tok = self.login("bob2", "pass")
+ self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+ user_3 = self.register_user("bob3", "pass")
+ user_3_tok = self.login("bob3", "pass")
+ self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+ # Test different sort orders, with forward and reverse directions
+ _order_test("name", [room_id_1, room_id_2, room_id_3])
+ _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+ _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+ _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+ _order_test(
+ "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("version", [room_id_1, room_id_2, room_id_3])
+ _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("creator", [room_id_1, room_id_2, room_id_3])
+ _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+ _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+ _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("public", [room_id_1, room_id_2, room_id_3])
+ # Different sort order of SQlite and PostreSQL
+ # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+ _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+ _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+ _order_test(
+ "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+ )
+
+ _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+ _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+ def test_search_term(self):
+ """Test that searching for a room works correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ def _search_test(
+ expected_room_id: Optional[str],
+ search_term: str,
+ expected_http_code: int = 200,
+ ):
+ """Search for a room and check that the returned room's id is a match
+
+ Args:
+ expected_room_id: The room_id expected to be returned by the API. Set
+ to None to expect zero results for the search
+ search_term: The term to search for room names with
+ expected_http_code: The expected http code for the request
+ """
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+ if expected_http_code != 200:
+ return
+
+ # Check that rooms were returned
+ self.assertTrue("rooms" in channel.json_body)
+ rooms = channel.json_body["rooms"]
+
+ # Check that the expected number of rooms were returned
+ expected_room_count = 1 if expected_room_id else 0
+ self.assertEqual(len(rooms), expected_room_count)
+ self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+ # Check that the offset is correct
+ # We're not paginating, so should be 0
+ self.assertEqual(channel.json_body["offset"], 0)
+
+ # Check that there is no `prev_batch`
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ # Check that there is no `next_batch`
+ self.assertNotIn("next_batch", channel.json_body)
+
+ if expected_room_id:
+ # Check that the first returned room id is correct
+ r = rooms[0]
+ self.assertEqual(expected_room_id, r["room_id"])
+
+ # Perform search tests
+ _search_test(room_id_1, "something")
+ _search_test(room_id_1, "thing")
+
+ _search_test(room_id_2, "else")
+ _search_test(room_id_2, "se")
+
+ _search_test(None, "foo")
+ _search_test(None, "bar")
+ _search_test(None, "", expected_http_code=400)
+
+ def test_single_room(self):
+ """Test that a single room can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ room_name_1 = "something"
+ room_name_2 = "else"
+
+ # Set the name for each room
+ self.helper.send_state(
+ room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+ )
+ self.helper.send_state(
+ room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+ )
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertIn("room_id", channel.json_body)
+ self.assertIn("name", channel.json_body)
+ self.assertIn("canonical_alias", channel.json_body)
+ self.assertIn("joined_members", channel.json_body)
+ self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("version", channel.json_body)
+ self.assertIn("creator", channel.json_body)
+ self.assertIn("encryption", channel.json_body)
+ self.assertIn("federatable", channel.json_body)
+ self.assertIn("public", channel.json_body)
+ self.assertIn("join_rules", channel.json_body)
+ self.assertIn("guest_access", channel.json_body)
+ self.assertIn("history_visibility", channel.json_body)
+ self.assertIn("state_events", channel.json_body)
+
+ self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+ def test_room_members(self):
+ """Test that room members can be requested correctly"""
+ # Create two test rooms
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ # Have another user join the room
+ user_2 = self.register_user("bar", "pass")
+ user_tok_2 = self.login("bar", "pass")
+ self.helper.join(room_id_1, user_2, tok=user_tok_2)
+ self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+ # Have another user join the room
+ user_3 = self.register_user("foobar", "pass")
+ user_tok_3 = self.login("foobar", "pass")
+ self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+ url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ self.assertCountEqual(
+ ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+ )
+ self.assertEqual(channel.json_body["total"], 3)
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.creator = self.register_user("creator", "test")
+ self.creator_tok = self.login("creator", "test")
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+
+ self.public_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=True
+ )
+ self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.second_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If a parameter is missing, return an error
+ """
+ body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ def test_local_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ body = json.dumps({"user_id": "@unknown:test"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_remote_user(self):
+ """
+ Check that only local user can join rooms.
+ """
+ body = json.dumps({"user_id": "@not:exist.bla"})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "This endpoint can only be used with local users",
+ channel.json_body["error"],
+ )
+
+ def test_room_does_not_exist(self):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/!unknown:test"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("No known servers", channel.json_body["error"])
+
+ def test_room_is_not_valid(self):
+ """
+ Check that invalid room names, return an error 400.
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+ url = "/_synapse/admin/v1/join/invalidroom"
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ "invalidroom was not legal room ID or room alias",
+ channel.json_body["error"],
+ )
+
+ def test_join_public_room(self):
+ """
+ Test joining a local user to a public room with "JoinRules.PUBLIC"
+ """
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ self.url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_not_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE"
+ when server admin is not member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_join_private_room_if_member(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is member of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.creator, tok=self.creator_tok, is_public=False
+ )
+ self.helper.invite(
+ room=private_room_id,
+ src=self.creator,
+ targ=self.admin_user,
+ tok=self.creator_tok,
+ )
+ self.helper.join(
+ room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+ )
+
+ # Validate if server admin is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ # Join user to room.
+
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+ def test_join_private_room_if_owner(self):
+ """
+ Test joining a local user to a private room with "JoinRules.INVITE",
+ when server admin is owner of this room.
+ """
+ private_room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
+ url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ body = json.dumps({"user_id": self.second_user_id})
+
+ request, channel = self.make_request(
+ "POST",
+ url,
+ content=body.encode(encoding="utf_8"),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+ # Validate if user is a member of the room
+
+ request, channel = self.make_request(
+ "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+ )
+ self.render(request)
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..7f8252330a 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -143,6 +143,26 @@ class RestHelper(object):
return channel.json_body
+ def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+
+ path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id)
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8")
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ return channel.json_body
+
def _read_write_state(
self,
room_id: str,
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to check the unread counts).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll check unread counts for.
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user (used to send events to the room).
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Change the power levels of the room so that the second user can send state
+ # events.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ {
+ "users": {self.user_id: 100, self.user2: 100},
+ "users_default": 0,
+ "events": {
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.history_visibility": 100,
+ "m.room.canonical_alias": 50,
+ "m.room.avatar": 50,
+ "m.room.tombstone": 100,
+ "m.room.server_acl": 100,
+ "m.room.encryption": 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 0,
+ },
+ tok=self.tok,
+ )
+
+ def test_unread_counts(self):
+ """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+ # Check that our own messages don't increase the unread count.
+ self.helper.send(self.room_id, "hello", tok=self.tok)
+ self._check_unread_count(0)
+
+ # Join the new user and check that this doesn't increase the unread count.
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+ self._check_unread_count(0)
+
+ # Check that the new user sending a message increases our unread count.
+ res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+ self._check_unread_count(1)
+
+ # Send a read receipt to tell the server we've read the latest event.
+ body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/read_markers" % self.room_id,
+ body,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the unread counter is back to 0.
+ self._check_unread_count(0)
+
+ # Check that room name changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ )
+ self._check_unread_count(1)
+
+ # Check that room topic changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ )
+ self._check_unread_count(2)
+
+ # Check that encrypted messages increase the unread counter.
+ self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+ self._check_unread_count(3)
+
+ # Check that custom events with a body increase the unread counter.
+ self.helper.send_event(
+ self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that edits don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": "hello",
+ "msgtype": "m.text",
+ "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ },
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that notices don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"body": "hello", "msgtype": "m.notice"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that tombstone events changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.Tombstone,
+ {"replacement_room": "!someroom:test"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(5)
+
+ def _check_unread_count(self, expected_count: True):
+ """Syncs and compares the unread count with the expected value."""
+
+ request, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ room_entry = channel.json_body["rooms"]["join"][self.room_id]
+ self.assertEqual(
+ room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ )
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
- def get_json(destination, path, ignore_backoff=False, **kwargs):
+ async def get_json(destination, path, ignore_backoff=False, **kwargs):
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
- def post_json(destination, path, data):
+ async def post_json(destination, path, data):
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path, "/_matrix/key/v2/query",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 66fa5978b2..f4f3e56777 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -26,6 +26,7 @@ import attr
from parameterized import parameterized_class
from PIL import Image as Image
+from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
@@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
# This uses a real blocking threadpool so we have to wait for it to be
# actually done :/
- x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+ x = defer.ensureDeferred(
+ self.media_storage.ensure_media_is_in_local_cache(file_info)
+ )
# Hotloop until the threadpool does its job...
self.wait_on_thread(x)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 2826211f32..74765a582b 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,8 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import os
+import re
+
+from mock import patch
import attr
@@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver()
def test_cache_returns_correct_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://matrix.org", shorthand=False
@@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
def test_non_ascii_preview_httpequiv(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_non_ascii_preview_content_type(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_overlong_title(self):
- self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
b"<html><head>"
@@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
IP addresses can be previewed directly.
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
"GET", "url_preview?url=http://example.com", shorthand=False
@@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Hardcode the URL resolving to the IP we want.
self.lookups["example.com"] = [
(IPv4Address, "1.1.1.2"),
- (IPv4Address, "8.8.8.8"),
+ (IPv4Address, "10.1.2.3"),
]
request, channel = self.make_request(
@@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""
Accept-Language header is sent to the remote server
"""
- self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+ self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
# Build and make a request to the server
request, channel = self.make_request(
@@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase):
),
server.data,
)
+
+ def test_oembed_photo(self):
+ """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
+ # Route the HTTP version to an HTTP endpoint so that the tests work.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
+
+ end_content = (
+ b"<html><head>"
+ b"<title>Some Title</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(oembed_content),)
+ + oembed_content
+ )
+
+ self.pump()
+
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ )
+
+ def test_oembed_rich(self):
+ """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
+ # Route the HTTP version to an HTTP endpoint so that the tests work.
+ with patch.dict(
+ "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+ {
+ re.compile(
+ r"http://twitter\.com/.+/status/.+"
+ ): "http://publish.twitter.com/oembed",
+ },
+ clear=True,
+ ):
+
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ request, channel = self.make_request(
+ "GET",
+ "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ )
+ request.render(self.preview_url)
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+ )
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"og:title": None, "og:description": "Content Preview"},
+ )
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 38f9b423ef..f2955a9c69 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,6 +14,7 @@
# limitations under the License.
import itertools
+from typing import List
import attr
@@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- state_before = self.successResultOf(state_d)
+ state_before = self.successResultOf(defer.ensureDeferred(state_d))
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(self.event_map),
)
- state = self.successResultOf(state_d)
+ state = self.successResultOf(defer.ensureDeferred(state_d))
self.assert_dict(self.expected_combined_state, state)
@@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
- return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ return defer.succeed(
+ {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ )
- def _get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
+ event_ids: The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ List of event IDs of the auth chain.
"""
# Simple DFS for auth chain
@@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
- return set(chains[0]).union(*chains[1:]) - common
+ return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..2b1580feeb 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_http(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield self.store.get_unread_push_actions_for_user_in_range_for_email(
- USER_ID, 0, 1000, 20
+ yield defer.ensureDeferred(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ USER_ID, 0, 1000, 20
+ )
)
@defer.inlineCallbacks
@@ -72,8 +76,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ yield defer.ensureDeferred(
+ self.store.add_push_actions_to_staging(
+ event.event_id, {user_id: action}
+ )
)
yield self.store.db.runInteraction(
"",
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..a6012c973d 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
event = self.successResultOf(event)
# Purge everything before this topological token
- purge = storage.purge_events.purge_history(self.room_id, event, True)
+ purge = defer.ensureDeferred(
+ storage.purge_events.purge_history(self.room_id, event, True)
+ )
self.pump()
self.assertEqual(self.successResultOf(purge), None)
@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
)
# Purge everything before this topological token
- purge = storage.purge_history(self.room_id, event, True)
+ purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..0f0e1cd09b 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
- built_event = yield self._base_builder.build(prev_event_ids)
+ built_event = yield defer.ensureDeferred(
+ self._base_builder.build(prev_event_ids)
+ )
built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index b1dceb2918..d07b985a8e 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id=self.u_creator.to_string(),
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id=self.u_creator.to_string(),
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -88,17 +90,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.storage.persistence.persist_event(
- self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(
+ self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+ )
)
@defer.inlineCallbacks
@@ -109,7 +115,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
@@ -125,7 +133,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5dd46005e6..f282921538 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test_get_joined_users_from_context(self):
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- bob_event = event_injection.inject_member_event(
- self.hs, room, self.u_bob, Membership.JOIN
+ bob_event = self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room, self.u_bob, Membership.JOIN
+ )
)
# first, create a regular event
- event, context = event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[bob_event.event_id],
- type="m.test.1",
- content={},
+ event, context = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.1",
+ content={},
+ )
)
users = self.get_success(
@@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Regression test for #7376: create a state event whose key matches bob's
# user_id, but which is *not* a membership event, and persist that; then check
# that `get_joined_users_from_context` returns the correct users for the next event.
- non_member_event = event_injection.inject_event(
- self.hs,
- room_id=room,
- sender=self.u_bob,
- prev_event_ids=[bob_event.event_id],
- type="m.test.2",
- state_key=self.u_bob,
- content={},
+ non_member_event = self.get_success(
+ event_injection.inject_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_bob,
+ prev_event_ids=[bob_event.event_id],
+ type="m.test.2",
+ state_key=self.u_bob,
+ content={},
+ )
)
- event, context = event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[non_member_event.event_id],
- type="m.test.3",
- content={},
+ event, context = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room,
+ sender=self.u_alice,
+ prev_event_ids=[non_member_event.event_id],
+ type="m.test.3",
+ content={},
+ )
)
users = self.get_success(
self.store.get_joined_users_from_context(event, context)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b88308ff4..8bd12fa847 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
+ yield defer.ensureDeferred(
+ self.store.store_room(
+ self.room.to_string(),
+ room_creator_user_id="@creator:text",
+ is_public=True,
+ room_version=RoomVersions.V1,
+ )
)
@defer.inlineCallbacks
@@ -64,11 +66,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups_ids(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.storage.state.get_state_groups(
- self.room, [e2.event_id]
+ state_group_map = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.storage.state.get_state_for_event(e5.event_id)
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(e5.event_id)
+ )
self.assertIsNotNone(e4)
@@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+ )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.storage.state.get_state_for_event(
- e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+ )
)
self.assertStateMapEqual(
@@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: {self.u_alice.to_string()}},
- include_others=True,
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ ),
+ )
)
self.assertStateMapEqual(
@@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.storage.state.get_state_for_event(
- e5.event_id,
- state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
- ),
+ state = yield defer.ensureDeferred(
+ self.storage.state.get_state_for_event(
+ e5.event_id,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()}, include_others=True
+ ),
+ )
)
self.assertStateMapEqual(
@@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.storage.state.get_state_groups_ids(
- room_id, [e5.event_id]
+ group_ids = yield defer.ensureDeferred(
+ self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 87a16d7d7a..c2f12c2741 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
prev_events that said event references.
"""
- def post_json(destination, path, data, headers=None, timeout=0):
+ async def post_json(destination, path, data, headers=None, timeout=0):
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
diff --git a/tests/test_server.py b/tests/test_server.py
index 030f58cbdc..073b2362cc 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -12,26 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import re
-from io import StringIO
from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite, logger
+from synapse.http.site import SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
from tests import unittest
from tests.server import (
- FakeTransport,
ThreadedMemoryReactorClock,
make_request,
render,
@@ -199,10 +193,10 @@ class OptionsResourceTests(unittest.TestCase):
return channel
def test_unknown_options_request(self):
- """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ """An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -219,10 +213,10 @@ class OptionsResourceTests(unittest.TestCase):
)
def test_known_options_request(self):
- """An OPTIONS requests to an known URL still returns 200 OK."""
+ """An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -318,54 +312,3 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
-
-
-class SiteTestCase(unittest.HomeserverTestCase):
- def test_lose_connection(self):
- """
- We log the URI correctly redacted when we lose the connection.
- """
-
- class HangingResource(Resource):
- """
- A Resource that strategically hangs, as if it were processing an
- answer.
- """
-
- def render(self, request):
- return NOT_DONE_YET
-
- # Set up a logging handler that we can inspect afterwards
- output = StringIO()
- handler = logging.StreamHandler(output)
- logger.addHandler(handler)
- old_level = logger.level
- logger.setLevel(10)
- self.addCleanup(logger.setLevel, old_level)
- self.addCleanup(logger.removeHandler, handler)
-
- # Make a resource and a Site, the resource will hang and allow us to
- # time out the request while it's 'processing'
- base_resource = Resource()
- base_resource.putChild(b"", HangingResource())
- site = SynapseSite(
- "test", "site_tag", self.hs.config.listeners[0], base_resource, "1.0"
- )
-
- server = site.buildProtocol(None)
- client = AccumulatingProtocol()
- client.makeConnection(FakeTransport(server, self.reactor))
- server.makeConnection(FakeTransport(client, self.reactor))
-
- # Send a request with an access token that will get redacted
- server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n")
- self.pump()
-
- # Lose the connection
- e = Failure(Exception("Failed123"))
- server.connectionLost(e)
- handler.flush()
-
- # Our access token is redacted and the failure reason is logged.
- self.assertIn("/?access_token=<redacted>", output.getvalue())
- self.assertIn("Failed123", output.getvalue())
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..b5c3667d2a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return state_group
+ return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
- return {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
+ return defer.succeed(
+ {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
+ )
def get_state_group_delta(self, name):
- return None, None
+ return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
- return RoomVersions.V1.identifier
+ return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@@ -202,14 +204,16 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -253,7 +259,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -310,7 +318,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
- prev_state_ids = yield ctx_e.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -383,7 +393,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -503,9 +517,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
+ @defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = self.store.store_state_group(
+ sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = self.store.store_state_group(
+ sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
- return self.state.compute_event_context(event)
+ result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ return result
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
"""
Utilities for running the unit tests
"""
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+ """Create an awaitable that just returns a result."""
+ return result
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 43297b530c..8522c6fc09 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -22,14 +22,12 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection
-from tests.test_utils import get_awaitable_result
-
"""
Utility functions for poking events into the storage of the server under test.
"""
-def inject_member_event(
+async def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
@@ -46,7 +44,7 @@ def inject_member_event(
if extra_content:
content.update(extra_content)
- return inject_event(
+ return await inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
@@ -57,7 +55,7 @@ def inject_member_event(
)
-def inject_event(
+async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
@@ -72,37 +70,27 @@ def inject_event(
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
- test_reactor = hs.get_reactor()
-
- event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- d = hs.get_storage().persistence.persist_event(event, context)
- test_reactor.advance(0)
- get_awaitable_result(d)
+ await hs.get_storage().persistence.persist_event(event, context)
return event
-def create_event(
+async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
- test_reactor = hs.get_reactor()
-
if room_version is None:
- d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
- test_reactor.advance(0)
- room_version = get_awaitable_result(d)
+ room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- d = hs.get_event_creation_handler().create_new_client_event(
+ event, context = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
- test_reactor.advance(0)
- event, context = get_awaitable_result(d)
return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f7381b2885..531a9b9118 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
- yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+ yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@defer.inlineCallbacks
def test_filtering(self):
@@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
#
# before we do that, we persist some other events to act as state.
- self.inject_visibility("@admin:hs", "joined")
+ yield self.inject_visibility("@admin:hs", "joined")
for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i)
@@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
for i in range(0, len(events_to_filter)):
@@ -137,10 +137,12 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
)
- yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -158,11 +160,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@defer.inlineCallbacks
@@ -179,11 +183,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
},
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder
+ event, context = yield defer.ensureDeferred(
+ self.event_creation_handler.create_new_client_event(builder)
)
- yield self.storage.persistence.persist_event(event, context)
+ yield defer.ensureDeferred(
+ self.storage.persistence.persist_event(event, context)
+ )
return event
@defer.inlineCallbacks
@@ -265,8 +271,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
storage.main = test_store
storage.state = test_store
- filtered = yield filter_events_for_server(
- test_store, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(test_store, "test_server", events_to_filter)
)
logger.info("Filtering took %f seconds", time.time() - start)
diff --git a/tests/unittest.py b/tests/unittest.py
index 3175a3fa02..68d2586efd 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
- event_injection.inject_member_event(self.hs, room, user, membership)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room, user, membership)
+ )
class FederatingHomeserverTestCase(HomeserverTestCase):
diff --git a/tests/utils.py b/tests/utils.py
index 4d17355a5c..b33b6860d4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -638,14 +638,8 @@ class DeferredMockCallable(object):
)
-@defer.inlineCallbacks
-def create_room(hs, room_id, creator_id):
+async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
-
- Args:
- hs
- room_id (str)
- creator_id (str)
"""
persistence_store = hs.get_storage().persistence
@@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
- yield store.store_room(
+ await store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=False,
@@ -671,6 +665,6 @@ def create_room(hs, room_id, creator_id):
},
)
- event, context = yield event_creation_handler.create_new_client_event(builder)
+ event, context = await event_creation_handler.create_new_client_event(builder)
- yield persistence_store.persist_event(event, context)
+ await persistence_store.persist_event(event, context)
|