summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py69
-rw-r--r--tests/api/test_filtering.py36
-rw-r--r--tests/appservice/test_appservice.py89
-rw-r--r--tests/appservice/test_scheduler.py19
-rw-r--r--tests/crypto/test_keyring.py50
-rw-r--r--tests/federation/test_complexity.py118
-rw-r--r--tests/federation/test_federation_sender.py10
-rw-r--r--tests/handlers/test_appservice.py7
-rw-r--r--tests/handlers/test_directory.py5
-rw-r--r--tests/handlers/test_identity.py116
-rw-r--r--tests/handlers/test_profile.py16
-rw-r--r--tests/handlers/test_register.py108
-rw-r--r--tests/handlers/test_stats.py116
-rw-r--r--tests/handlers/test_typing.py4
-rw-r--r--tests/handlers/test_user_directory.py157
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py2
-rw-r--r--tests/http/test_fedclient.py50
-rw-r--r--tests/module_api/test_api.py146
-rw-r--r--tests/push/test_http.py12
-rw-r--r--tests/replication/_base.py6
-rw-r--r--tests/replication/slave/storage/test_events.py6
-rw-r--r--tests/replication/test_federation_sender_shard.py13
-rw-r--r--tests/rest/admin/test_admin.py4
-rw-r--r--tests/rest/admin/test_room.py61
-rw-r--r--tests/rest/admin/test_user.py10
-rw-r--r--tests/rest/client/test_identity.py145
-rw-r--r--tests/rest/client/test_retention.py2
-rw-r--r--tests/rest/client/test_room_access_rules.py1066
-rw-r--r--tests/rest/client/test_third_party_rules.py170
-rw-r--r--tests/rest/client/third_party_rules.py79
-rw-r--r--tests/rest/client/v1/test_profile.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/test_typing.py6
-rw-r--r--tests/rest/client/v1/utils.py24
-rw-r--r--tests/rest/client/v2_alpha/test_account.py103
-rw-r--r--tests/rest/client/v2_alpha/test_register.py207
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/rest/test_health.py34
-rw-r--r--tests/rulecheck/__init__.py14
-rw-r--r--tests/rulecheck/test_domainrulecheck.py334
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py4
-rw-r--r--tests/storage/test__base.py16
-rw-r--r--tests/storage/test_appservice.py6
-rw-r--r--tests/storage/test_background_update.py8
-rw-r--r--tests/storage/test_base.py22
-rw-r--r--tests/storage/test_cleanup_extrems.py12
-rw-r--r--tests/storage/test_client_ips.py26
-rw-r--r--tests/storage/test_directory.py32
-rw-r--r--tests/storage/test_end_to_end_keys.py12
-rw-r--r--tests/storage/test_event_federation.py16
-rw-r--r--tests/storage/test_event_push_actions.py30
-rw-r--r--tests/storage/test_id_generators.py14
-rw-r--r--tests/storage/test_main.py4
-rw-r--r--tests/storage/test_monthly_active_users.py23
-rw-r--r--tests/storage/test_profile.py8
-rw-r--r--tests/storage/test_purge.py8
-rw-r--r--tests/storage/test_redaction.py8
-rw-r--r--tests/storage/test_room.py30
-rw-r--r--tests/storage/test_roommember.py12
-rw-r--r--tests/storage/test_state.py76
-rw-r--r--tests/storage/test_user_directory.py4
-rw-r--r--tests/test_federation.py2
-rw-r--r--tests/test_server.py45
-rw-r--r--tests/test_state.py14
-rw-r--r--tests/test_types.py22
-rw-r--r--tests/test_visibility.py26
-rw-r--r--tests/unittest.py30
-rw-r--r--tests/util/test_retryutils.py44
-rw-r--r--tests/utils.py22
69 files changed, 3428 insertions, 576 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py

index 0bfb86bf1f..5d45689c8c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) + self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) self.store.is_support_user = Mock(return_value=defer.succeed(False)) @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = {"name": self.test_user, "token_id": "ditto"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" @@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") @@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + # This just needs to return a truth-y value. + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) self.failureResultOf(d, AuthError) @defer.inlineCallbacks def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org", "device_id": "device"} + return_value=defer.succeed( + {"name": "@baldrick:matrix.org", "device_id": "device"} + ) ) user_id = "@baldrick:matrix.org" @@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value={"is_guest": True}) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase): def get_user(tok): if token != tok: - return None - return { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + return defer.succeed(None) + return defer.succeed( + { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock(return_value={"is_guest": False}) + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) # check the token works request = Mock(args={}) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..1fab1d6b69 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py
@@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart + "2", filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events=events) @@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events) @@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals( user_filter_json, ( - yield self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 + yield defer.ensureDeferred( + self.datastore.get_user_filter( + user_localpart=user_localpart, filter_id=0 + ) ) ), ) @@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase): user_localpart=user_localpart, user_filter=user_filter_json ) - filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) self.assertEquals(filter.get_filter_json(), user_filter_json) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 4003869ed6..236b608d58 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py
@@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_regex_user_id_prefix_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_member_is_checked(self): @@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_match(self): @@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_no_match(self): @@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_alias_match(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#irc_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) def test_non_exclusive_alias(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( @@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#xmpp_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertFalse((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertFalse( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_regex_multiple_matches(self): @@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_barfoo:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_interested_in_self(self): @@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.type = "m.room.member" self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) - self.store.get_users_in_room.return_value = [ - "@alice:here", - "@irc_fo:here", # AS user - "@bob:here", - ] - self.store.get_aliases_for_room.return_value = [] + # Note that @irc_fo:here is the AS user. + self.store.get_users_in_room.return_value = defer.succeed( + ["@alice:here", "@irc_fo:here", "@bob:here"] + ) + self.store.get_aliases_for_room.return_value = defer.succeed([]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( - (yield self.service.is_interested(event=self.event, store=self.store)) + ( + yield defer.ensureDeferred( + self.service.is_interested(event=self.event, store=self.store) + ) + ) ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 52f89d3f83..68a4caabbf 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py
@@ -25,6 +25,7 @@ from synapse.appservice.scheduler import ( from synapse.logging.context import make_deferred_yieldable from tests import unittest +from tests.test_utils import make_awaitable from ..utils import MockClock @@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.get_appservice_state = Mock( return_value=defer.succeed(ApplicationServiceState.UP) ) - txn.send = Mock(return_value=defer.succeed(True)) + txn.send = Mock(return_value=make_awaitable(True)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): return_value=defer.succeed(ApplicationServiceState.UP) ) self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) - txn.send = Mock(return_value=defer.succeed(False)) # fails to send + txn.send = Mock(return_value=make_awaitable(False)) # fails to send self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events @@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=True) + txn.send = Mock(return_value=make_awaitable(True)) + txn.complete.return_value = make_awaitable(None) # wait for exp backoff self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) @@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=False) + txn.send = Mock(return_value=make_awaitable(False)) + txn.complete.return_value = make_awaitable(None) self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) @@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEquals(3, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, self.callback.call_count) - txn.send = Mock(return_value=True) # successfully send the txn + txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEquals(1, txn.send.call_count) # new mock reset call count diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f9ce609923..0d4b05304b 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -40,6 +40,7 @@ from synapse.logging.context import ( from synapse.storage.keys import FetchKeyResult from tests import unittest +from tests.test_utils import make_awaitable class MockPerspectiveServer(object): @@ -102,11 +103,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): } persp_deferred = defer.Deferred() - @defer.inlineCallbacks - def get_perspectives(**kwargs): + async def get_perspectives(**kwargs): self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): - yield persp_deferred + await persp_deferred return persp_resp self.http_client.post_json.side_effect = get_perspectives @@ -202,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): with a null `ts_valid_until_ms` """ mock_fetcher = keyring.KeyFetcher() - mock_fetcher.get_keys = Mock(return_value=defer.succeed({})) + mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) @@ -245,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key(1) - def get_keys(keys_to_fetch): + async def get_keys(keys_to_fetch): # there should only be one request object (with the max validity) self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher = keyring.KeyFetcher() mock_fetcher.get_keys = Mock(side_effect=get_keys) @@ -282,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): """If the first fetcher cannot provide a recent enough key, we fall back""" key1 = signedjson.key.generate_signing_key(1) - def get_keys1(keys_to_fetch): + async def get_keys1(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800) - } - } - ) + return { + "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} + } - def get_keys2(keys_to_fetch): + async def get_keys2(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher1 = keyring.KeyFetcher() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) @@ -355,7 +347,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): } signedjson.sign.sign_json(response, SERVER_NAME, testkey) - def get_json(destination, path, **kwargs): + async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) self.assertEqual(path, "/_matrix/key/v2/server/key1") return response @@ -444,7 +436,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect a perspectives-server key query """ - def post_json(destination, path, data, **kwargs): + async def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(path, "/_matrix/key/v2/query") @@ -580,14 +572,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # remove the perspectives server's signature response = build_response() del response["signatures"][self.mock_perspective_server.server_name] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") # remove the origin server's signature response = build_response() del response["signatures"][SERVER_NAME] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0c9987be54..b8ca118716 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -78,9 +79,40 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # Check whether an admin can join if option "admins_can_join" is undefined, + # this option defaults to false, so the join should fail. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -116,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -141,3 +173,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): f = self.get_failure(d, SynapseError) self.assertEqual(f.value.code, 400) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + +class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): + # Test the behavior of joining rooms which exceed the complexity if option + # limit_remote_rooms.admins_can_join is True. + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["limit_remote_rooms"] = { + "enabled": True, + "complexity": 0.05, + "admins_can_join": True, + } + return config + + def test_join_too_large_no_admin(self): + # A user which is not an admin should not be able to join a remote room + # which is too complex. + + u1 = self.register_user("u1", "pass") + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # An admin should be able to join rooms where a complexity check fails. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=make_awaitable(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request success since the user is an admin + self.get_success(d) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index d1bd18da39..5f512ff8bf 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py
@@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): receipt = ReadReceipt( "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() mock_send_transaction.assert_not_called() diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ebabe9a7d6..2a0b7c1b56 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.handlers.appservice import ApplicationServicesHandler +from tests.test_utils import make_awaitable from tests.utils import MockClock from .. import unittest @@ -117,9 +118,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice_alias(is_interested_in_alias=False), ] - self.mock_as_api.query_alias.return_value = defer.succeed(True) + self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_association_from_room_alias.return_value = defer.succeed( + self.mock_store.get_association_from_room_alias.return_value = make_awaitable( Mock(room_id=room_id, servers=servers) ) @@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def _mkservice(self, is_interested): service = Mock() - service.is_interested.return_value = defer.succeed(is_interested) + service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" return service diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@ from mock import Mock -from twisted.internet import defer - import synapse import synapse.api.errors from synapse.api.constants import EventTypes @@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest +from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): @@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py new file mode 100644
index 0000000000..0ab0356109 --- /dev/null +++ b/tests/handlers/test_identity.py
@@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +import synapse.rest.admin +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account + +from tests import unittest + + +class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.address = "test@test" + self.is_server_name = "testis" + self.is_server_url = "https://testis" + self.rewritten_is_url = "https://int.testis" + + config = self.default_config() + config["trusted_third_party_id_servers"] = [self.is_server_name] + config["rewrite_identity_server_urls"] = { + self.is_server_url: self.rewritten_is_url + } + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.get_json.side_effect = defer.succeed({}) + mock_http_client.post_json_get_json.return_value = defer.succeed( + {"address": self.address, "medium": "email"} + ) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_blacklisting_http_client.get_json.side_effect = defer.succeed({}) + mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed( + {"address": self.address, "medium": "email"} + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_handlers().identity_handler.blacklisting_http_client = ( + mock_blacklisting_http_client + ) + + return self.hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + + def test_rewritten_id_server(self): + """ + Tests that, when validating a 3PID association while rewriting the IS's server + name: + * the bind request is done against the rewritten hostname + * the original, non-rewritten, server name is stored in the database + """ + handler = self.hs.get_handlers().identity_handler + post_json_get_json = handler.blacklisting_http_client.post_json_get_json + store = self.hs.get_datastore() + + creds = {"sid": "123", "client_secret": "some_secret"} + + # Make sure processing the mocked response goes through. + data = self.get_success( + handler.bind_threepid( + client_secret=creds["client_secret"], + sid=creds["sid"], + mxid=self.user_id, + id_server=self.is_server_name, + use_v2=False, + ) + ) + self.assertEqual(data.get("address"), self.address) + + # Check that the request was done against the rewritten server name. + post_json_get_json.assert_called_once_with( + "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,), + { + "sid": creds["sid"], + "client_secret": creds["client_secret"], + "mxid": self.user_id, + }, + headers={}, + ) + + # Check that the original server name is saved in the database instead of the + # rewritten one. + id_servers = self.get_success( + store.get_id_servers_user_bound(self.user_id, "email", self.address) + ) + self.assertEqual(id_servers, [self.is_server_name]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 4f1347cd25..655c1393b7 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -63,14 +64,12 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) - self.handler = hs.get_profile_handler() self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) displayname = yield defer.ensureDeferred( self.handler.get_displayname(self.frank) @@ -111,7 +110,7 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1) self.assertEquals( (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", @@ -138,7 +137,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_other_name(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) @@ -156,8 +155,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") - yield self.store.set_profile_displayname("caroline", "Caroline") + yield self.store.set_profile_displayname("caroline", "Caroline", 1) response = yield defer.ensureDeferred( self.query_handlers["profile"]( @@ -170,7 +168,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.localpart, "http://my.server/me.png", 1 ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -211,7 +209,7 @@ class ProfileTestCase(unittest.TestCase): # Setting displayname for the first time is allowed yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.localpart, "http://my.server/me.png", 1 ) self.assertEquals( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 6d45c4b233..e951a62a6d 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -20,8 +20,14 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.http.site import SynapseRequest +from synapse.rest.client.v2_alpha.register import ( + _map_email_to_displayname, + register_servlets, +) from synapse.types import RoomAlias, UserID, create_requester +from tests.server import FakeChannel from tests.unittest import override_config from .. import unittest @@ -35,6 +41,10 @@ class RegistrationHandlers(object): class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ + servlets = [ + register_servlets, + ] + def make_homeserver(self, reactor, clock): hs_config = self.default_config() @@ -474,6 +484,104 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_email_to_displayname_mapping(self): + """Test that custom emails are mapped to new user displaynames correctly""" + self._check_mapping( + "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]" + ) + + self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]") + + self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]") + + # Multibyte unicode characters + self._check_mapping( + "j\u030a\u0065an-poppy.seed@example.com", + "J\u030a\u0065an-Poppy Seed [Example]", + ) + + def _check_mapping(self, i, expected): + result = _map_email_to_displayname(i) + self.assertEqual(result, expected) + + @override_config( + { + "bind_new_user_emails_to_sydent": "https://is.example.com", + "registrations_require_3pid": ["email"], + "account_threepid_delegates": {}, + "email": { + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + }, + "public_baseurl": "http://localhost", + } + ) + def test_user_email_bound_via_sydent_internal_api(self): + """Tests that emails are bound after registration if this option is set""" + # Register user with an email address + email = "alice@example.com" + + # Mock Synapse's threepid validator + get_threepid_validation_session = Mock( + return_value=defer.succeed( + {"medium": "email", "address": email, "validated_at": 0} + ) + ) + self.store.get_threepid_validation_session = get_threepid_validation_session + delete_threepid_session = Mock(return_value=defer.succeed(None)) + self.store.delete_threepid_session = delete_threepid_session + + # Mock Synapse's http json post method to check for the internal bind call + post_json_get_json = Mock(return_value=defer.succeed(None)) + self.hs.get_simple_http_client().post_json_get_json = post_json_get_json + + # Retrieve a UIA session ID + channel = self.uia_register( + 401, {"username": "alice", "password": "nobodywillguessthis"} + ) + session_id = channel.json_body["session"] + + # Register our email address using the fake validation session above + channel = self.uia_register( + 200, + { + "username": "alice", + "password": "nobodywillguessthis", + "auth": { + "session": session_id, + "type": "m.login.email.identity", + "threepid_creds": {"sid": "blabla", "client_secret": "blablabla"}, + }, + }, + ) + self.assertEqual(channel.json_body["user_id"], "@alice:test") + + # Check that a bind attempt was made to our fake identity server + post_json_get_json.assert_called_with( + "https://is.example.com/_matrix/identity/internal/bind", + {"address": "alice@example.com", "medium": "email", "mxid": "@alice:test"}, + ) + + # Check that we stored a mapping of this bind + bound_threepids = self.get_success( + self.store.user_get_bound_threepids("@alice:test") + ) + self.assertListEqual(bound_threepids, [{"medium": "email", "address": email}]) + + def uia_register(self, expected_response: int, body: dict) -> FakeChannel: + """Make a register request.""" + request, channel = self.make_request( + "POST", "register", body + ) # type: SynapseRequest, FakeChannel + self.render(request) + + self.assertEqual(request.code, expected_response) + return channel + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..4b627dac00 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py
@@ -15,14 +15,20 @@ from synapse.rest import admin from synapse.rest.client.v1 import login, room -from synapse.storage.data_stores.main import stats +from synapse.storage.databases.main import stats from tests import unittest # The expected number of state events in a fresh public room. EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5 + # The expected number of state events in a fresh private room. -EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 +# +# Note: we increase this by 2 on the dinsic branch as we send +# a "im.vector.room.access_rules" state event into new private rooms, +# and an encryption state event as all private rooms are encrypted +# by default +EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7 class StatsRoomTests(unittest.HomeserverTestCase): @@ -42,36 +48,36 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms", + "update_name": "populate_stats_process_rooms_2", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", + "depends_on": "populate_stats_process_rooms_2", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +88,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store.db.simple_select_list( + return self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +102,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store.db.simple_select_one( + self.store.db_pool.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -109,10 +115,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_initial_room(self): @@ -146,10 +152,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r = self.get_success(self.get_all_room_state()) @@ -186,9 +192,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_update_one( + self.store.db_pool.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -196,17 +202,17 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now, before the table is actually ingested, add some more events. @@ -217,28 +223,31 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", - {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, + { + "update_name": "populate_stats_process_rooms_2", + "progress_json": "{}", + }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", + "depends_on": "populate_stats_process_rooms_2", }, ) ) - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) self.reactor.advance(86401) @@ -346,6 +355,37 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) + def test_updating_profile_information_does_not_increase_joined_members_count(self): + """ + Check that the joined_members count does not increase when a user changes their + profile information (which is done by sending another join membership event into + the room. + """ + self._perform_background_initial_update() + + # Create a user and room + u1 = self.register_user("u1", "pass") + u1token = self.login("u1", "pass") + r1 = self.helper.create_room_as(u1, tok=u1token) + + # Get the current room stats + r1stats_ante = self._get_current_stats("room", r1) + + # Send a profile update into the room + new_profile = {"displayname": "bob"} + self.helper.change_membership( + r1, u1, u1, "join", extra_data=new_profile, tok=u1token + ) + + # Get the new room stats + r1stats_post = self._get_current_stats("room", r1) + + # Ensure that the user count did not changed + self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"]) + self.assertEqual( + r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"] + ) + def test_send_state_event_nonoverwriting(self): """ When we send a non-overwriting state event, it increments total_events AND current_state_events @@ -669,15 +709,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store.db.simple_delete( + self.store.db_pool.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -689,29 +729,29 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms", + "update_name": "populate_stats_process_rooms_2", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", + "depends_on": "populate_stats_process_rooms_2", }, ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -722,10 +762,10 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) r1stats_complete = self._get_current_stats("room", r1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5878f74175..b7d0adb10e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - return defer.succeed(None) + return None hs.get_auth().check_user_in_room = check_user_in_room diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 23fcc372dd..46c3810e70 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client.v2_alpha import account, account_validity, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest @@ -339,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -350,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -362,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -374,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -384,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -394,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -437,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() @@ -476,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self._add_background_updates() while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) shares_private = self.get_users_who_share_private_rooms() @@ -549,3 +549,136 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) + + +class UserInfoTestCase(unittest.FederatingHomeserverTestCase): + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + account_validity.register_servlets, + synapse.rest.client.v2_alpha.user_directory.register_servlets, + account.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + # Set accounts to expire after a week + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + return config + + def prepare(self, reactor, clock, hs): + super(UserInfoTestCase, self).prepare(reactor, clock, hs) + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def test_user_info(self): + """Test /users/info for local users from the Client-Server API""" + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request info about each user from user_three + request, channel = self.make_request( + "POST", + path="/_matrix/client/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + access_token=user_three_token, + shorthand=False, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def test_user_info_federation(self): + """Test that /users/info can be called from the Federation API, and + and that we can query remote users from the Client-Server API + """ + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request information about our local users from the perspective of a remote server + request, channel = self.make_request( + "POST", + path="/_matrix/federation/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def setup_test_users(self): + """Create an admin user and three test users, each with a different state""" + + # Create an admin user to expire other users with + self.register_user("admin", "adminpassword", admin=True) + admin_token = self.login("admin", "adminpassword") + + # Create three users + user_one = self.register_user("alice", "pass") + user_one_token = self.login("alice", "pass") + user_two = self.register_user("bob", "pass") + user_three = self.register_user("carl", "pass") + user_three_token = self.login("carl", "pass") + + # Deactivate user_one + self.deactivate(user_one, user_one_token) + + # Expire user_two + self.expire(user_two, admin_token) + + # Do nothing to user_three + + return user_one, user_two, user_three, user_three_token + + def expire(self, user_id_to_expire, admin_tok): + url = "/_matrix/client/unstable/admin/account_validity/validity" + request_data = { + "user_id": user_id_to_expire, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request, channel = self.make_request( + "POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def deactivate(self, user_id, tok): + request_data = { + "auth": {"type": "m.login.password", "user": user_id, "password": "pass"}, + "erase": False, + } + request, channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.render(request) + self.assertEqual(request.code, 200) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 69945a8f98..db260d599e 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -101,7 +101,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=self.tls_factory, + tls_client_options_factory=FederationPolicyForHTTPS(config), user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..ac598249e4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase): @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: - fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + fetch_d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar") + ) # Nothing happened yet self.assertNoResult(fetch_d) @@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase): """ If the DNS lookup returns an error, it will bubble up. """ - d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + ) self.pump() f = self.failureResultOf(d) @@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value.inner_exception, DNSLookupError) def test_client_connection_refused(self): - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a blacklisted IPv4 address # ------------------------------------------------------ # Make the request - d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000)) # Nothing happened yet self.assertNoResult(d) @@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase): # Try making a POST request to a blacklisted IPv6 address # ------------------------------------------------------- # Make the request - d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + ) # Nothing has happened yet self.assertNoResult(d) @@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a non-blacklisted IPv4 address # ---------------------------------------------------------- # Make the request - d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000)) # Nothing has happened yet self.assertNoResult(d) @@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase): request = MatrixFederationRequest( method="GET", destination="testserv:8008", path="foo/bar" ) - d = self.cl._send_request(request, timeout=10000) + d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000)) self.pump() @@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) + defer.ensureDeferred( + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} + ) + ) self.pump() @@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase): def test_closes_connection(self): """Check that the client closes unused HTTP connections""" - d = self.cl.get_json("testserv:8008", "foo/bar") + d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) self.pump() diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..9c778a0e45 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -12,16 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock +from synapse.events import EventBase from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.types import create_requester from tests.unittest import HomeserverTestCase class ModuleApiTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler()) + self.event_creation_handler = homeserver.get_event_creation_handler() def test_can_register_user(self): """Tests that an external module can register a user""" @@ -52,3 +64,137 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the displayname was assigned displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + + def test_sending_events_into_room(self): + """Tests that a module can send events into a room""" + # Mock out create_and_send_nonmember_event to check whether events are being sent + self.event_creation_handler.create_and_send_nonmember_event = Mock( + spec=[], + side_effect=self.event_creation_handler.create_and_send_nonmember_event, + ) + + # Create a user and room to play with + user_id = self.register_user("summer", "monkey") + tok = self.login("summer", "monkey") + room_id = self.helper.create_room_as(user_id, tok=tok) + + # Create and send a non-state event + content = {"body": "I am a puppet", "msgtype": "m.text"} + event_dict = { + "room_id": room_id, + "type": "m.room.message", + "content": content, + "sender": user_id, + } + event = self.get_success( + self.module_api.create_and_send_event_into_room(event_dict) + ) # type: EventBase + self.assertEqual(event.sender, user_id) + self.assertEqual(event.type, "m.room.message") + self.assertEqual(event.room_id, room_id) + self.assertFalse(hasattr(event, "state_key")) + self.assertDictEqual(event.content, content) + + # Check that the event was sent + self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( + create_requester(user_id), event_dict, ratelimit=False, + ) + + # Create and send a state event + content = { + "events_default": 0, + "users": {user_id: 100}, + "state_default": 50, + "users_default": 0, + "events": {"test.event.type": 25}, + } + event_dict = { + "room_id": room_id, + "type": "m.room.power_levels", + "content": content, + "sender": user_id, + "state_key": "", + } + event = self.get_success( + self.module_api.create_and_send_event_into_room(event_dict) + ) # type: EventBase + self.assertEqual(event.sender, user_id) + self.assertEqual(event.type, "m.room.power_levels") + self.assertEqual(event.room_id, room_id) + self.assertEqual(event.state_key, "") + self.assertDictEqual(event.content, content) + + # Check that the event was sent + self.event_creation_handler.create_and_send_nonmember_event.assert_called_with( + create_requester(user_id), + { + "type": "m.room.power_levels", + "content": content, + "room_id": room_id, + "sender": user_id, + "state_key": "", + }, + ratelimit=False, + ) + + # Check that we can't send membership events + content = { + "membership": "leave", + } + event_dict = { + "room_id": room_id, + "type": "m.room.member", + "content": content, + "sender": user_id, + "state_key": user_id, + } + self.get_failure( + self.module_api.create_and_send_event_into_room(event_dict), Exception + ) + + def test_public_rooms(self): + """Tests that a room can be added and removed from the public rooms list, + as well as have its public rooms directory state queried. + """ + # Create a user and room to play with + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + room_id = self.helper.create_room_as(user_id, tok=tok) + + # The room should not currently be in the public rooms directory + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertFalse(is_in_public_rooms) + + # Let's try adding it to the public rooms directory + self.get_success( + self.module_api.public_room_list_manager.add_room_to_public_room_list( + room_id + ) + ) + + # And checking whether it's in there... + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertTrue(is_in_public_rooms) + + # Let's remove it again + self.get_success( + self.module_api.public_room_list_manager.remove_room_from_public_room_list( + room_id + ) + ) + + # Should be gone + is_in_public_rooms = self.get_success( + self.module_api.public_room_list_manager.room_is_in_public_room_list( + room_id + ) + ) + self.assertFalse(is_in_public_rooms) diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..2f56cacc7a 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -346,8 +346,8 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_mention(self): """ @@ -418,8 +418,8 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") def test_sends_high_priority_for_atroom(self): """ @@ -497,5 +497,5 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(self.push_attempts), 2) self.assertEqual(self.push_attempts[1][1], "example.com") - # check that this is low-priority - self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + # check that this is high-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 06575ba0a6..ae60874ec3 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -65,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Since we use sqlite in memory databases we need to make sure the # databases objects are the same. - self.worker_hs.get_datastore().db = hs.get_datastore().db + self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler @@ -198,7 +198,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.streamer = self.hs.get_replication_streamer() store = self.hs.get_datastore() - self.database = store.db + self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" @@ -254,7 +254,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): ) store = worker_hs.get_datastore() - store.db._db_pool = self.database._db_pool + store.db_pool._db_pool = self.database_pool._db_pool repl_handler = ReplicationCommandHandler(worker_hs) client = ClientReplicationStreamProtocol( diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..0b5204654c 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py
@@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_handler = self.hs.get_state_handler() context = self.get_success(state_handler.compute_event_context(event)) - self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + self.get_success( + self.master_store.add_push_actions_to_staging( + event.event_id, {user_id: actions for user_id, actions in push_actions} + ) ) return event, context diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8d4dbf232e..83f9aa291c 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py
@@ -16,8 +16,6 @@ import logging from mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource @@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import make_awaitable logger = logging.getLogger(__name__) @@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index b1a4decced..0f1144fe1e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + async def get_file(destination, path, output_stream, args=None, max_size=None): """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): d = Deferred() d.addCallback(write_to) self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + return await make_deferred_yieldable(d) client = Mock() client.get_file = get_file diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ba8552c29f..408c568a27 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py
@@ -283,6 +283,23 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + body = json.dumps({"purge": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + def test_purge_room_and_block(self): """Test to purge a room and block it. Members will not be moved to a new room and will not receive a message. @@ -297,7 +314,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": True}) + body = json.dumps({"block": True, "purge": True}) request, channel = self.make_request( "POST", @@ -331,7 +348,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": False}) + body = json.dumps({"block": False, "purge": True}) request, channel = self.make_request( "POST", @@ -351,6 +368,42 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_blocked(self.room_id, expect=False) self._has_no_members(self.room_id) + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False, "purge": False}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not yet accepted the privacy policy. This used to fail when we tried to @@ -513,7 +566,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", @@ -614,7 +667,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f16eef15f7..17d0aae2e9 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -20,6 +20,8 @@ import urllib.parse from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import HttpResponseException, ResourceLimitError @@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() # Set monthly active users to the limit - store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit self.get_failure( @@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..4224b0a92e 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@ import json +from mock import Mock + +from twisted.internet import defer + import synapse.rest.admin from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import account from tests import unittest -class IdentityTestCase(unittest.HomeserverTestCase): +class IdentityDisabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts fail when the HS's config disallows them.""" servlets = [ + account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, @@ -32,24 +39,111 @@ class IdentityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["trusted_third_party_id_servers"] = ["testis"] config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + def test_3pid_invite_disabled(self): + request, channel = self.make_request( + b"POST", "/createRoom", b"{}", access_token=self.tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + room_id = channel.json_body["room_id"] + + params = { + "id_server": "testis", + "medium": "email", + "address": "test@example.com", + } + request_data = json.dumps(params) + request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") + request, channel = self.make_request( + b"POST", request_url, request_data, access_token=self.tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) + request, channel = self.make_request("GET", url, access_token=self.tok) + self.render(request) + self.assertEqual(channel.result["code"], b"403", channel.result) + + def test_3pid_bulk_lookup_disabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], + } + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.tok + ) + self.render(request) + self.assertEqual(channel.result["code"], b"403", channel.result) + + +class IdentityEnabledTestCase(unittest.HomeserverTestCase): + """Tests that 3PID lookup attempts succeed when the HS's config allows them.""" + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] - self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["enable_3pid_lookup"] = True + config["trusted_third_party_id_servers"] = ["testis"] + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.get_json.return_value = defer.succeed((200, "{}")) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_handlers().identity_handler.http_client = mock_http_client + + return self.hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + def test_3pid_invite_enabled(self): request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok + b"POST", "/createRoom", b"{}", access_token=self.tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) room_id = channel.json_body["room_id"] + # Replace the blacklisting SimpleHttpClient with our mock + self.hs.get_room_member_handler().simple_http_client = Mock( + spec=["get_json", "post_json_get_json"] + ) + self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed( + (200, "{}") + ) + params = { "id_server": "testis", "medium": "email", @@ -58,7 +152,44 @@ class IdentityTestCase(unittest.HomeserverTestCase): request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") request, channel = self.make_request( - b"POST", request_url, request_data, access_token=tok + b"POST", request_url, request_data, access_token=self.tok ) self.render(request) - self.assertEquals(channel.result["code"], b"403", channel.result) + + get_json = self.hs.get_handlers().identity_handler.http_client.get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + {"address": "test@example.com", "medium": "email"}, + ) + + def test_3pid_lookup_enabled(self): + url = ( + "/_matrix/client/unstable/account/3pid/lookup" + "?id_server=testis&medium=email&address=foo@bar.baz" + ) + request, channel = self.make_request("GET", url, access_token=self.tok) + self.render(request) + + get_json = self.hs.get_simple_http_client().get_json + get_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/lookup", + {"address": "foo@bar.baz", "medium": "email"}, + ) + + def test_3pid_bulk_lookup_enabled(self): + url = "/_matrix/client/unstable/account/3pid/bulk_lookup" + data = { + "id_server": "testis", + "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]], + } + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.tok + ) + self.render(request) + + post_json = self.hs.get_simple_http_client().post_json_get_json + post_json.assert_called_once_with( + "https://testis/_matrix/identity/api/v1/bulk_lookup", + {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]}, + ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index e54ffea150..cc264cf0b5 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["default_room_version"] = "1" config["retention"] = { "enabled": True, "default_policy": { @@ -203,6 +204,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + config["default_room_version"] = "1" config["retention"] = { "enabled": True, } diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py new file mode 100644
index 0000000000..de7856fba9 --- /dev/null +++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1066 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import random +import string +from typing import Optional + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset +from synapse.rest import admin +from synapse.rest.client.v1 import directory, login, room +from synapse.third_party_rules.access_rules import ( + ACCESS_RULES_TYPE, + AccessRules, + RoomAccessRules, +) +from synapse.types import JsonDict, create_requester + +from tests import unittest + + +class RoomAccessTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["third_party_event_rules"] = { + "module": "synapse.third_party_rules.access_rules.RoomAccessRules", + "config": { + "domains_forbidden_when_restricted": ["forbidden_domain"], + "id_server": "testis", + }, + } + config["trusted_third_party_id_servers"] = ["testis"] + + def send_invite(destination, room_id, event_id, pdu): + return defer.succeed(pdu) + + def get_json(uri, args={}, headers=None): + address_domain = args["address"].split("@")[1] + return defer.succeed({"hs": address_domain}) + + def post_json_get_json(uri, post_json, args={}, headers=None): + token = "".join(random.choice(string.ascii_letters) for _ in range(10)) + return defer.succeed( + { + "token": token, + "public_keys": [ + { + "public_key": "serverpublickey", + "key_validity_url": "https://testis/pubkey/isvalid", + }, + { + "public_key": "phemeralpublickey", + "key_validity_url": "https://testis/pubkey/ephemeral/isvalid", + }, + ], + "display_name": "f...@b...", + } + ) + + mock_federation_client = Mock(spec=["send_invite"]) + mock_federation_client.send_invite.side_effect = send_invite + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"],) + # Mocking the response for /info on the IS API. + mock_http_client.get_json.side_effect = get_json + # Mocking the response for /store-invite on the IS API. + mock_http_client.post_json_get_json.side_effect = post_json_get_json + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + simple_http_client=mock_http_client, + ) + + # TODO: This class does not use a singleton to get it's http client + # This should be fixed for easier testing + # https://github.com/matrix-org/synapse-dinsic/issues/26 + self.hs.get_handlers().identity_handler.blacklisting_http_client = ( + mock_http_client + ) + + self.third_party_event_rules = self.hs.get_third_party_event_rules() + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + self.restricted_room = self.create_room() + self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED) + self.direct_rooms = [ + self.create_room(direct=True), + self.create_room(direct=True), + self.create_room(direct=True), + ] + + self.invitee_id = self.register_user("invitee", "test") + self.invitee_tok = self.login("invitee", "test") + + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + ) + + def test_create_room_no_rule(self): + """Tests that creating a room with no rule will set the default.""" + room_id = self.create_room() + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.RESTRICTED) + + def test_create_room_direct_no_rule(self): + """Tests that creating a direct room with no rule will set the default.""" + room_id = self.create_room(direct=True) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.DIRECT) + + def test_create_room_valid_rule(self): + """Tests that creating a room with a valid rule will set the right.""" + room_id = self.create_room(rule=AccessRules.UNRESTRICTED) + rule = self.current_rule_in_room(room_id) + + self.assertEqual(rule, AccessRules.UNRESTRICTED) + + def test_create_room_invalid_rule(self): + """Tests that creating a room with an invalid rule will set fail.""" + self.create_room(rule=AccessRules.DIRECT, expected_code=400) + + def test_create_room_direct_invalid_rule(self): + """Tests that creating a direct room with an invalid rule will fail. + """ + self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400) + + def test_create_room_default_power_level_rules(self): + """Tests that a room created with no power level overrides instead uses the dinum + defaults + """ + room_id = self.create_room(direct=True, rule=AccessRules.DIRECT) + power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok) + + # Inviting another user should require PL50, even in private rooms + self.assertEqual(power_levels["invite"], 50) + # Sending arbitrary state events should require PL100 + self.assertEqual(power_levels["state_default"], 100) + + def test_create_room_fails_on_incorrect_power_level_rules(self): + """Tests that a room created with power levels lower than that required are rejected""" + modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id) + modified_power_levels["invite"] = 0 + modified_power_levels["state_default"] = 50 + + self.create_room( + direct=True, + rule=AccessRules.DIRECT, + initial_state=[ + {"type": "m.room.power_levels", "content": modified_power_levels} + ], + expected_code=400, + ) + + def test_existing_room_can_change_power_levels(self): + """Tests that a room created with default power levels can have their power levels + dropped after room creation + """ + # Creates a room with the default power levels + room_id = self.create_room( + direct=True, rule=AccessRules.DIRECT, expected_code=200, + ) + + # Attempt to drop invite and state_default power levels after the fact + room_power_levels = self.helper.get_state( + room_id, "m.room.power_levels", self.tok + ) + room_power_levels["invite"] = 0 + room_power_levels["state_default"] = 50 + self.helper.send_state( + room_id, "m.room.power_levels", room_power_levels, self.tok + ) + + def test_public_room(self): + """Tests that it's only possible to have a room listed in the public room list + if the access rule is restricted. + """ + # Creating a room with the public_chat preset should succeed and set the access + # rule to restricted. + preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT) + self.assertEqual( + self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED + ) + + # Creating a room with the public join rule in its initial state should succeed + # and set the access rule to restricted. + init_state_room_id = self.create_room( + initial_state=[ + { + "type": "m.room.join_rules", + "content": {"join_rule": JoinRules.PUBLIC}, + } + ] + ) + self.assertEqual( + self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED + ) + + # List preset_room_id in the public room list + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,), + {"visibility": "public"}, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # List init_state_room_id in the public room list + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,), + {"visibility": "public"}, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # Changing access rule to unrestricted should fail. + self.change_rule_in_room( + preset_room_id, AccessRules.UNRESTRICTED, expected_code=403 + ) + self.change_rule_in_room( + init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403 + ) + + # Changing access rule to direct should fail. + self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403) + self.change_rule_in_room( + init_state_room_id, AccessRules.DIRECT, expected_code=403 + ) + + # Creating a new room with the public_chat preset and an access rule of direct + # should fail. + self.create_room( + preset=RoomCreationPreset.PUBLIC_CHAT, + rule=AccessRules.DIRECT, + expected_code=400, + ) + + # Changing join rule to public in an direct room should fail. + self.change_join_rule_in_room( + self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403 + ) + + def test_restricted(self): + """Tests that in restricted mode we're unable to invite users from blacklisted + servers but can invite other users. + + Also tests that the room can be published to, and removed from, the public room + list. + """ + # We can't invite a user from a forbidden HS. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=403, + ) + + # We can invite a user which HS isn't forbidden. + self.helper.invite( + room=self.restricted_room, + src=self.user_id, + targ="@test:allowed_domain", + tok=self.tok, + expect_code=200, + ) + + # We can't send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.restricted_room, + expected_code=403, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.restricted_room, + expected_code=200, + ) + + # We are allowed to publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # We are allowed to remove the room from the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room + data = {"visibility": "private"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + def test_direct(self): + """Tests that, in direct mode, other users than the initial two can't be invited, + but the following scenario works: + * invited user joins the room + * invited user leaves the room + * room creator re-invites invited user + + Tests that a user from a HS that's in the list of forbidden domains (to use + in restricted mode) can be invited. + + Tests that the room cannot be published to the public room list. + """ + not_invited_user = "@not_invited:forbidden_domain" + + # We can't invite a new user to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=403, + ) + + # The invited user can join the room. + self.helper.join( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can leave the room. + self.helper.leave( + room=self.direct_rooms[0], + user=self.invitee_id, + tok=self.invitee_tok, + expect_code=200, + ) + + # The invited user can be re-invited to the room. + self.helper.invite( + room=self.direct_rooms[0], + src=self.user_id, + targ=self.invitee_id, + tok=self.tok, + expect_code=200, + ) + + # If we're alone in the room and have always been the only member, we can invite + # someone. + self.helper.invite( + room=self.direct_rooms[1], + src=self.user_id, + targ=not_invited_user, + tok=self.tok, + expect_code=200, + ) + + # Disable the 3pid invite ratelimiter + burst = self.hs.config.rc_third_party_invite.burst_count + per_second = self.hs.config.rc_third_party_invite.per_second + self.hs.config.rc_third_party_invite.burst_count = 10 + self.hs.config.rc_third_party_invite.per_second = 0.1 + + # We can't send a 3PID invite to a room that already has two members. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[0], + expected_code=403, + ) + + # We can't send a 3PID invite to a room that already has a pending invite. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.direct_rooms[1], + expected_code=403, + ) + + # We can send a 3PID invite to a room in which we've always been the only member. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=200, + ) + + # We can send a 3PID invite to a room in which there's a 3PID invite. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.direct_rooms[2], + expected_code=403, + ) + + self.hs.config.rc_third_party_invite.burst_count = burst + self.hs.config.rc_third_party_invite.per_second = per_second + + # We can't publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0] + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 403, channel.result) + + def test_unrestricted(self): + """Tests that, in unrestricted mode, we can invite whoever we want, but we can + only change the power level of users that wouldn't be forbidden in restricted + mode. + + Tests that the room cannot be published to the public room list. + """ + # We can invite + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + self.helper.invite( + room=self.unrestricted_room, + src=self.user_id, + targ="@test:not_forbidden_domain", + tok=self.tok, + expect_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to a forbidden HS. + self.send_threepid_invite( + address="test@forbidden_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a 3PID invite to an address that is mapped to an HS that's not + # forbidden. + self.send_threepid_invite( + address="test@allowed_domain", + room_id=self.unrestricted_room, + expected_code=200, + ) + + # We can send a power level event that doesn't redefine the default PL or set a + # non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}}, + tok=self.tok, + expect_code=200, + ) + + # We can't send a power level event that redefines the default PL and doesn't set + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={ + "users": {self.user_id: 100, "@test:not_forbidden_domain": 10}, + "users_default": 10, + }, + tok=self.tok, + expect_code=403, + ) + + # We can't send a power level event that doesn't redefines the default PL but sets + # a non-default PL for a user that would be forbidden in restricted mode. + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.PowerLevels, + body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}}, + tok=self.tok, + expect_code=403, + ) + + # We can't publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 403, channel.result) + + def test_change_rules(self): + """Tests that we can only change the current rule from restricted to + unrestricted. + """ + # We can't change the rule from restricted to direct. + self.change_rule_in_room( + room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403 + ) + + # We can change the rule from restricted to unrestricted. + # Note that this changes self.restricted_room to an unrestricted room + self.change_rule_in_room( + room_id=self.restricted_room, + new_rule=AccessRules.UNRESTRICTED, + expected_code=200, + ) + + # We can't change the rule from unrestricted to restricted. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=AccessRules.RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from unrestricted to direct. + self.change_rule_in_room( + room_id=self.unrestricted_room, + new_rule=AccessRules.DIRECT, + expected_code=403, + ) + + # We can't change the rule from direct to restricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=AccessRules.RESTRICTED, + expected_code=403, + ) + + # We can't change the rule from direct to unrestricted. + self.change_rule_in_room( + room_id=self.direct_rooms[0], + new_rule=AccessRules.UNRESTRICTED, + expected_code=403, + ) + + # We can't publish a room to the public room list and then change its rule to + # unrestricted + + # Create a restricted room + test_room_id = self.create_room(rule=AccessRules.RESTRICTED) + + # Publish the room to the public room list + url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id + data = {"visibility": "public"} + + request, channel = self.make_request("PUT", url, data, access_token=self.tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + # Attempt to switch the room to "unrestricted" + self.change_rule_in_room( + room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403 + ) + + # Attempt to switch the room to "direct" + self.change_rule_in_room( + room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403 + ) + + def test_change_room_avatar(self): + """Tests that changing the room avatar is always allowed unless the room is a + direct chat, in which case it's forbidden. + """ + + avatar_content = { + "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394}, + "url": "mxc://example.org/JWEIFJgwEIhweiWJE", + } + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.RoomAvatar, + body=avatar_content, + tok=self.tok, + expect_code=403, + ) + + def test_change_room_name(self): + """Tests that changing the room name is always allowed unless the room is a direct + chat, in which case it's forbidden. + """ + + name_content = {"name": "My super room"} + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.Name, + body=name_content, + tok=self.tok, + expect_code=403, + ) + + def test_change_room_topic(self): + """Tests that changing the room topic is always allowed unless the room is a + direct chat, in which case it's forbidden. + """ + + topic_content = {"topic": "Welcome to this room"} + + self.helper.send_state( + room_id=self.restricted_room, + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.unrestricted_room, + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=200, + ) + + self.helper.send_state( + room_id=self.direct_rooms[0], + event_type=EventTypes.Topic, + body=topic_content, + tok=self.tok, + expect_code=403, + ) + + def test_revoke_3pid_invite_direct(self): + """Tests that revoking a 3PID invite doesn't cause the room access rules module to + confuse the revokation as a new 3PID invite. + """ + invite_token = "sometoken" + + invite_body = { + "display_name": "ker...@exa...", + "public_keys": [ + { + "key_validity_url": "https://validity_url", + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", + }, + { + "key_validity_url": "https://validity_url", + "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I", + }, + ], + "key_validity_url": "https://validity_url", + "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA", + } + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body=invite_body, + tok=self.tok, + ) + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body={}, + tok=self.tok, + ) + + invite_token = "someothertoken" + + self.send_state_with_state_key( + room_id=self.direct_rooms[1], + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body=invite_body, + tok=self.tok, + ) + + def test_check_event_allowed(self): + """Tests that RoomAccessRules.check_event_allowed behaves accordingly. + + It tests that: + * forbidden users cannot join restricted rooms. + * forbidden users can only join unrestricted rooms if they have an invite. + """ + event_creator = self.hs.get_event_creation_handler() + + # Test that forbidden users cannot join restricted rooms + requester = create_requester(self.user_id) + allowed_requester = create_requester("@user:allowed_domain") + forbidden_requester = create_requester("@user:forbidden_domain") + + # Create a join event for a forbidden user + forbidden_join_event, forbidden_join_event_context = self.get_success( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.restricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ) + ) + + # Create a join event for an allowed user + allowed_join_event, allowed_join_event_context = self.get_success( + event_creator.create_event( + allowed_requester, + { + "type": EventTypes.Member, + "room_id": self.restricted_room, + "sender": allowed_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": allowed_requester.user.to_string(), + }, + ) + ) + + # Assert a join event from a forbidden user to a restricted room is rejected + can_join = self.get_success( + self.third_party_event_rules.check_event_allowed( + forbidden_join_event, forbidden_join_event_context + ) + ) + self.assertFalse(can_join) + + # But a join event from an non-forbidden user to a restricted room is allowed + can_join = self.get_success( + self.third_party_event_rules.check_event_allowed( + allowed_join_event, allowed_join_event_context + ) + ) + self.assertTrue(can_join) + + # Test that forbidden users can only join unrestricted rooms if they have an invite + + # Recreate the forbidden join event for the unrestricted room instead + forbidden_join_event, forbidden_join_event_context = self.get_success( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.unrestricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ) + ) + + # A forbidden user without an invite should not be able to join an unrestricted room + can_join = self.get_success( + self.third_party_event_rules.check_event_allowed( + forbidden_join_event, forbidden_join_event_context + ) + ) + self.assertFalse(can_join) + + # However, if we then invite this user... + self.helper.invite( + room=self.unrestricted_room, + src=requester.user.to_string(), + targ=forbidden_requester.user.to_string(), + tok=self.tok, + ) + + # And create another join event, making sure that its context states it's coming + # in after the above invite was made... + forbidden_join_event, forbidden_join_event_context = self.get_success( + event_creator.create_event( + forbidden_requester, + { + "type": EventTypes.Member, + "room_id": self.unrestricted_room, + "sender": forbidden_requester.user.to_string(), + "content": {"membership": Membership.JOIN}, + "state_key": forbidden_requester.user.to_string(), + }, + ) + ) + + # Then the forbidden user should be able to join! + can_join = self.get_success( + self.third_party_event_rules.check_event_allowed( + forbidden_join_event, forbidden_join_event_context + ) + ) + self.assertTrue(can_join) + + def test_freezing_a_room(self): + """Tests that the power levels in a room change to prevent new events from + non-admin users when the last admin of a room leaves. + """ + + def freeze_room_with_id_and_power_levels( + room_id: str, custom_power_levels_content: Optional[JsonDict] = None, + ): + # Invite a user to the room, they join with PL 0 + self.helper.invite( + room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok, + ) + + # Invitee joins the room + self.helper.join( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + if not custom_power_levels_content: + # Retrieve the room's current power levels event content + power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.tok, + ) + else: + power_levels = custom_power_levels_content + + # Override the room's power levels with the given power levels content + self.helper.send_state( + room_id=room_id, + event_type="m.room.power_levels", + body=custom_power_levels_content, + tok=self.tok, + ) + + # Ensure that the invitee leaving the room does not change the power levels + self.helper.leave( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + # Retrieve the new power levels of the room + new_power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.tok, + ) + + # Ensure they have not changed + self.assertDictEqual(power_levels, new_power_levels) + + # Invite the user back again + self.helper.invite( + room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok, + ) + + # Invitee joins the room + self.helper.join( + room=room_id, user=self.invitee_id, tok=self.invitee_tok, + ) + + # Now the admin leaves the room + self.helper.leave( + room=room_id, user=self.user_id, tok=self.tok, + ) + + # Check the power levels again + new_power_levels = self.helper.get_state( + room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok, + ) + + # Ensure that the new power levels prevent anyone but admins from sending + # certain events + self.assertEquals(new_power_levels["state_default"], 100) + self.assertEquals(new_power_levels["events_default"], 100) + self.assertEquals(new_power_levels["kick"], 100) + self.assertEquals(new_power_levels["invite"], 100) + self.assertEquals(new_power_levels["ban"], 100) + self.assertEquals(new_power_levels["redact"], 100) + self.assertDictEqual(new_power_levels["events"], {}) + self.assertDictEqual(new_power_levels["users"], {self.user_id: 100}) + + # Ensure new users entering the room aren't going to immediately become admins + self.assertEquals(new_power_levels["users_default"], 0) + + # Test that freezing a room with the default power level state event content works + room1 = self.create_room() + freeze_room_with_id_and_power_levels(room1) + + # Test that freezing a room with a power level state event that is missing + # `state_default` and `event_default` keys behaves as expected + room2 = self.create_room() + freeze_room_with_id_and_power_levels( + room2, + { + "ban": 50, + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + }, + "invite": 0, + "kick": 50, + "redact": 50, + "users": {self.user_id: 100}, + "users_default": 0, + # Explicitly remove `state_default` and `event_default` keys + }, + ) + + # Test that freezing a room with a power level state event that is *additionally* + # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected + room3 = self.create_room() + freeze_room_with_id_and_power_levels( + room3, + { + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + }, + "users": {self.user_id: 100}, + "users_default": 0, + # Explicitly remove `state_default` and `event_default` keys + # Explicitly remove `ban`, `invite`, `kick` and `redact` keys + }, + ) + + def create_room( + self, + direct=False, + rule=None, + preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT, + initial_state=None, + expected_code=200, + ): + content = {"is_direct": direct, "preset": preset} + + if rule: + content["initial_state"] = [ + {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}} + ] + + if initial_state: + if "initial_state" not in content: + content["initial_state"] = [] + + content["initial_state"] += initial_state + + request, channel = self.make_request( + "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + if expected_code == 200: + return channel.json_body["room_id"] + + def current_rule_in_room(self, room_id): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["rule"] + + def change_rule_in_room(self, room_id, new_rule, expected_code=200): + data = {"rule": new_rule} + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE), + json.dumps(data), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200): + data = {"join_rule": new_join_rule} + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules), + json.dumps(data), + access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + def send_threepid_invite(self, address, room_id, expected_code=200): + params = {"id_server": "testis", "medium": "email", "address": address} + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/invite" % room_id, + json.dumps(params), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def send_state_with_state_key( + self, room_id, event_type, state_key, body, tok, expect_code=200 + ): + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) + + request, channel = self.make_request( + "PUT", path, json.dumps(body), access_token=tok + ) + self.render(request) + + self.assertEqual(channel.code, expect_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py new file mode 100644
index 0000000000..d03e121664 --- /dev/null +++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +from typing import Dict + +from mock import Mock + +from synapse.events import EventBase +from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.types import Requester, StateMap + +from tests import unittest + +thread_local = threading.local() + + +class ThirdPartyRulesTestModule: + def __init__(self, config: Dict, module_api: ModuleApi): + # keep a record of the "current" rules module, so that the test can patch + # it if desired. + thread_local.rules_module = self + self.module_api = module_api + + async def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return True + + async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + return True + + @staticmethod + def parse_config(config): + return config + + +def current_rules_module() -> ThirdPartyRulesTestModule: + return thread_local.rules_module + + +class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["third_party_event_rules"] = { + "module": __name__ + ".ThirdPartyRulesTestModule", + "config": {}, + } + return config + + def prepare(self, reactor, clock, homeserver): + # Create a user and room to play with during the tests + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_third_party_rules(self): + """Tests that a forbidden event is forbidden from being sent, but an allowed one + can be sent. + """ + # patch the rules module with a Mock which will return False for some event + # types + async def check(ev, state): + return ev.type != "foo.bar.forbidden" + + callback = Mock(spec=[], side_effect=check) + current_rules_module().check_event_allowed = callback + + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id, + {}, + access_token=self.tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + callback.assert_called_once() + + # there should be various state events in the state arg: do some basic checks + state_arg = callback.call_args[0][1] + for k in (("m.room.create", ""), ("m.room.member", self.user_id)): + self.assertIn(k, state_arg) + ev = state_arg[k] + self.assertEqual(ev.type, k[0]) + self.assertEqual(ev.state_key, k[1]) + + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id, + {}, + access_token=self.tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_modify_event(self): + """Tests that the module can successfully tweak an event before it is persisted. + """ + # first patch the event checker so that it will modify the event + async def check(ev: EventBase, state): + ev.content = {"x": "y"} + return True + + current_rules_module().check_event_allowed = check + + # now send the event + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, + {"x": "x"}, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.result["code"], b"200", channel.result) + event_id = channel.json_body["event_id"] + + # ... and check that it got modified + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["x"], "y") + + def test_send_event(self): + """Tests that the module can send an event into a room via the module api""" + content = { + "msgtype": "m.text", + "body": "Hello!", + } + event_dict = { + "room_id": self.room_id, + "type": "m.room.message", + "content": content, + "sender": self.user_id, + } + event = self.get_success( + current_rules_module().module_api.create_and_send_event_into_room( + event_dict + ) + ) # type: EventBase + + self.assertEquals(event.sender, self.user_id) + self.assertEquals(event.room_id, self.room_id) + self.assertEquals(event.type, "m.room.message") + self.assertEquals(event.content, content) diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py deleted file mode 100644
index 7167fc56b6..0000000000 --- a/tests/rest/client/third_party_rules.py +++ /dev/null
@@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest import admin -from synapse.rest.client.v1 import login, room - -from tests import unittest - - -class ThirdPartyRulesTestModule(object): - def __init__(self, config): - pass - - def check_event_allowed(self, event, context): - if event.type == "foo.bar.forbidden": - return False - else: - return True - - @staticmethod - def parse_config(config): - return config - - -class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["third_party_event_rules"] = { - "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule", - "config": {}, - } - - self.hs = self.setup_test_homeserver(config=config) - return self.hs - - def test_third_party_rules(self): - """Tests that a forbidden event is forbidden from being sent, but an allowed one - can be sent. - """ - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - room_id = self.helper.create_room_as(user_id, tok=tok) - - request, channel = self.make_request( - "PUT", - "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id, - {}, - access_token=tok, - ) - self.render(request) - self.assertEquals(channel.result["code"], b"200", channel.result) - - request, channel = self.make_request( - "PUT", - "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id, - {}, - access_token=tok, - ) - self.render(request) - self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase): profile_handler=self.mock_handler, ) - def _get_user_by_req(request=None, allow_guest=False): - return defer.succeed(synapse.types.create_requester(myid)) + async def _get_user_by_req(request=None, allow_guest=False): + return synapse.types.create_requester(myid) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5ccda8b2bd..ef6b775ed2 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -23,8 +23,6 @@ from urllib import parse as urlparse from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus @@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None self.hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_handlers().federation_handler = Mock() - def get_user_by_access_token(token=None, allow_guest=False): + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, @@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..8933b560d2 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -88,7 +88,28 @@ class RestHelper(object): expect_code=expect_code, ) - def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): + def change_membership( + self, + room: str, + src: str, + targ: str, + membership: str, + extra_data: dict = {}, + tok: Optional[str] = None, + expect_code: int = 200, + ) -> None: + """ + Send a membership state event into a room. + + Args: + room: The ID of the room to send to + src: The mxid of the event sender + targ: The mxid of the event's target. The state key + membership: The type of membership event + extra_data: Extra information to include in the content of the event + tok: The user access token to use + expect_code: The expected HTTP response code + """ temp_id = self.auth_user_id self.auth_user_id = src @@ -97,6 +118,7 @@ class RestHelper(object): path = path + "?access_token=%s" % tok data = {"membership": membership} + data.update(extra_data) request, channel = make_request( self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..0a51aeff92 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import os import re from email.parser import Parser +from typing import Optional import pkg_resources @@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from tests import unittest +from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): @@ -668,16 +669,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def _request_token(self, email, client_secret): + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, + "POST", b"account/3pid/email/requestToken", body, ) self.render(request) - self.assertEquals(200, channel.code, channel.result) + self.assertEquals(expect_code, channel.code, channel.result) - return channel.json_body["sid"] + return channel.json_body.get("sid") def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7deaf5b24a..fce79b38f2 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime import json import os +from mock import Mock + import pkg_resources +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -87,14 +91,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.json_body["error"], "Invalid password") - def test_POST_bad_username(self): - request_data = json.dumps({"username": 777, "password": "monkey"}) - request, channel = self.make_request(b"POST", self.url, request_data) - self.render(request) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") - def test_POST_user_valid(self): user_id = "@kermit:test" device_id = "frogfone" @@ -116,8 +112,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) + @override_config({"enable_registration": False}) def test_POST_disabled_registration(self): - self.hs.config.enable_registration = False request_data = json.dumps({"username": "kermit", "password": "monkey"}) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) @@ -303,6 +299,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(channel.json_body.get("sid")) +class RegisterHideProfileTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/register" + + config = self.default_config() + config["enable_registration"] = True + config["show_users_in_user_directory"] = False + config["replicate_user_profiles_to"] = ["fakeserver"] + + mock_http_client = Mock(spec=["get_json", "post_json_get_json"]) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + return self.hs + + def test_profile_hidden(self): + user_id = self.register_user("kermit", "monkey") + + post_json = self.hs.get_simple_http_client().post_json_get_json + + # We expect post_json_get_json to have been called twice: once with the original + # profile and once with the None profile resulting from the request to hide it + # from the user directory. + self.assertEqual(post_json.call_count, 2, post_json.call_args_list) + + # Get the args (and not kwargs) passed to post_json. + args = post_json.call_args[0] + # Make sure the last call was attempting to replicate profiles. + split_uri = args[0].split("/") + self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0]) + # Make sure the last profile update was overriding the user's profile to None. + self.assertEqual(args[1]["batch"][user_id], None, args[1]) + + class AccountValidityTestCase(unittest.HomeserverTestCase): servlets = [ @@ -312,6 +349,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): sync.register_servlets, logout.register_servlets, account_validity.register_servlets, + account.register_servlets, ] def make_homeserver(self, reactor, clock): @@ -437,6 +475,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) +class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.client.v1.profile.register_servlets, + synapse.rest.client.v1.room.register_servlets, + synapse.rest.client.v2_alpha.user_directory.register_servlets, + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Set accounts to expire after a week + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + config["replicate_user_profiles_to"] = "test.is" + + # Mock homeserver requests to an identity server + mock_http_client = Mock(spec=["post_json_get_json"]) + mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}")) + + self.hs = self.setup_test_homeserver( + config=config, simple_http_client=mock_http_client + ) + + return self.hs + + def test_expired_user_in_directory(self): + """Test that an expired user is hidden in the user directory""" + # Create an admin user to search the user directory + admin_id = self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + # Ensure the admin never expires + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": admin_id, + "expiration_ts": 999999999999, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Mock the homeserver's HTTP client + post_json = self.hs.get_simple_http_client().post_json_get_json + + # Create a user + username = "kermit" + user_id = self.register_user(username, "monkey") + self.login(username, "monkey") + self.get_success( + self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1) + ) + + # Check that a full profile for this user is replicated + self.assertIsNotNone(post_json.call_args, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + + self.assertIsNotNone(batch, batch) + self.assertEquals(len(batch), 1, batch) + + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's not None + replicated_content = batch[user_id] + self.assertIsNotNone(replicated_content) + + # Expire the user + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Wait for the background job to run which hides expired users in the directory + self.reactor.advance(60 * 60 * 1000) + + # Check if the homeserver has replicated the user's profile to the identity server + self.assertIsNotNone(post_json.call_args, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + + self.assertIsNotNone(batch, batch) + self.assertEquals(len(batch), 1, batch) + + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's None, signifying that the user should be removed from the user + # directory because they were expired + replicated_content = batch[user_id] + self.assertIsNone(replicated_content) + + # Now renew the user, and check they get replicated again to the identity server + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 99999999999, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + request, channel = self.make_request( + b"POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.pump(10) + self.reactor.advance(10) + self.pump() + + # Check if the homeserver has replicated the user's profile to the identity server + post_json = self.hs.get_simple_http_client().post_json_get_json + self.assertNotEquals(post_json.call_args, None, post_json.call_args) + payload = post_json.call_args[0][1] + batch = payload.get("batch") + self.assertNotEquals(batch, None, batch) + self.assertEquals(len(batch), 1, batch) + replicated_user_id = list(batch.keys())[0] + self.assertEquals(replicated_user_id, user_id, replicated_user_id) + + # There was replicated information about our user + # Check that it's not None, signifying that the user is back in the user + # directory + replicated_content = batch[user_id] + self.assertIsNotNone(replicated_content) + + class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): servlets = [ @@ -587,7 +774,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(request.code, 200, channel.result) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json(destination, path, ignore_backoff=False, **kwargs): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - def post_json(destination, path, data): + async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, "/_matrix/key/v2/query", diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py new file mode 100644
index 0000000000..2d021f6565 --- /dev/null +++ b/tests/rest/test_health.py
@@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.rest.health import HealthResource + +from tests import unittest + + +class HealthCheckTests(unittest.HomeserverTestCase): + def setUp(self): + super().setUp() + + # replace the JsonResource with a HealthResource. + self.resource = HealthResource() + + def test_health(self): + request, channel = self.make_request("GET", "/health", shorthand=False) + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rulecheck/__init__.py b/tests/rulecheck/__init__.py new file mode 100644
index 0000000000..a354d38ca8 --- /dev/null +++ b/tests/rulecheck/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py new file mode 100644
index 0000000000..1accc70dc9 --- /dev/null +++ b/tests/rulecheck/test_domainrulecheck.py
@@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json + +import synapse.rest.admin +from synapse.config._base import ConfigError +from synapse.rest.client.v1 import login, room +from synapse.rulecheck.domain_rule_checker import DomainRuleChecker + +from tests import unittest +from tests.server import make_request, render + + +class DomainRuleCheckerTestCase(unittest.TestCase): + def test_allowed(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + "domains_prevented_from_being_invited_to_published_rooms": ["target_two"], + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_one", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False + ) + ) + self.assertTrue( + check.user_may_invite( + "test:source_two", "test:target_two", None, "room", False + ) + ) + + # User can invite internal user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test1:target_one", None, "room", False, True + ) + ) + + # User can invite external user to a non-published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, False + ) + ) + + def test_disallowed(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + "source_four": [], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_one", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_three", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_two", "test:target_one", None, "room", False + ) + ) + self.assertFalse( + check.user_may_invite( + "test:source_four", "test:target_one", None, "room", False + ) + ) + + # User cannot invite external user to a published room + self.assertTrue( + check.user_may_invite( + "test:source_one", "test:target_two", None, "room", False, True + ) + ) + + def test_default_allow(self): + config = { + "default": True, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertTrue( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_default_deny(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + check = DomainRuleChecker(config) + self.assertFalse( + check.user_may_invite( + "test:source_three", "test:target_one", None, "room", False + ) + ) + + def test_config_parse(self): + config = { + "default": False, + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + }, + } + self.assertEquals(config, DomainRuleChecker.parse_config(config)) + + def test_config_parse_failure(self): + config = { + "domain_mapping": { + "source_one": ["target_one", "target_two"], + "source_two": ["target_two"], + } + } + self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config) + + +class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + hijack_auth = False + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["trusted_third_party_id_servers"] = ["localhost"] + + config["spam_checker"] = { + "module": "synapse.rulecheck.domain_rule_checker.DomainRuleChecker", + "config": { + "default": True, + "domain_mapping": {}, + "can_only_join_rooms_with_invite": True, + "can_only_create_one_to_one_rooms": True, + "can_only_invite_during_room_creation": True, + "can_invite_by_third_party_id": False, + }, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor, clock, hs): + self.admin_user_id = self.register_user("admin_user", "pass", admin=True) + self.admin_access_token = self.login("admin_user", "pass") + + self.normal_user_id = self.register_user("normal_user", "pass", admin=False) + self.normal_access_token = self.login("normal_user", "pass") + + self.other_user_id = self.register_user("other_user", "pass", admin=False) + + def test_admin_can_create_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + def test_normal_user_cannot_create_empty_room(self): + channel = self._create_room(self.normal_access_token) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_cannot_create_room_with_multiple_invites(self): + channel = self._create_room( + self.normal_access_token, + content={"invite": [self.other_user_id, self.admin_user_id]}, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly counts both normal and third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [self.other_user_id], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + # Test that it correctly rejects third party invites + channel = self._create_room( + self.normal_access_token, + content={ + "invite": [], + "invite_3pid": [{"medium": "email", "address": "foo@example.com"}], + }, + ) + assert channel.result["code"] == b"403", channel.result + + def test_normal_user_can_room_with_single_invites(self): + channel = self._create_room( + self.normal_access_token, content={"invite": [self.other_user_id]} + ) + assert channel.result["code"] == b"200", channel.result + + def test_cannot_join_public_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=403 + ) + + def test_can_join_invited_room(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + def test_cannot_invite(self): + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + def test_cannot_3pid_invite(self): + """Test that unbound 3pid invites get rejected. + """ + channel = self._create_room(self.admin_access_token) + assert channel.result["code"] == b"200", channel.result + + room_id = channel.json_body["room_id"] + + self.helper.invite( + room_id, + src=self.admin_user_id, + targ=self.normal_user_id, + tok=self.admin_access_token, + ) + + self.helper.join( + room_id, self.normal_user_id, tok=self.normal_access_token, expect_code=200 + ) + + self.helper.invite( + room_id, + src=self.normal_user_id, + targ=self.other_user_id, + tok=self.normal_access_token, + expect_code=403, + ) + + request, channel = self.make_request( + "POST", + "rooms/%s/invite" % (room_id), + {"address": "foo@bar.com", "medium": "email", "id_server": "localhost"}, + access_token=self.normal_access_token, + ) + self.render(request) + self.assertEqual(channel.code, 403, channel.result["body"]) + + def _create_room(self, token, content={}): + path = "/_matrix/client/r0/createRoom?access_token=%s" % (token,) + + request, channel = make_request( + self.hs.get_reactor(), + "POST", + path, + content=json.dumps(content).encode("utf8"), + ) + render(request, self.resource, self.hs.get_reactor()) + + return channel diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 99908edba3..3f88abe3d2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=1000) + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) @@ -275,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) - token = self.get_success(self.event_source.get_current_token()) + token = self.event_source.get_current_token() events, _ = self.get_success( self.store.get_recent_events_for_room( room_id, limit=100, end_token=token.room_key diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..319e2c2325 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.db.runInteraction( + self.storage.db_pool.runInteraction( "test", - self.storage.db.simple_upsert_many_txn, + self.storage.db_pool.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage.db.simple_select_list( + self.storage.db_pool.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..1b516b7976 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -24,11 +24,11 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(TestTransactionStore, self).__init__(database, db_conn, hs) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..2efbc97c2e 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py
@@ -9,7 +9,9 @@ from tests import unittest class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + self.updates = ( + self.hs.get_datastore().db_pool.updates + ) # type: BackgroundUpdater # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) @@ -29,7 +31,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() self.get_success( - store.db.simple_insert( + store.db_pool.simple_insert( "background_updates", values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, ) @@ -40,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def update(progress, count): yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield store.db.runInteraction( + yield store.db_pool.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index b589506c60..efcaeef1e7 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -21,7 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest @@ -57,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) db._db_pool = self.db_pool self.datastore = SQLBaseStore(db, None, hs) @@ -66,7 +66,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( + yield self.datastore.db_pool.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -78,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_insert( + yield self.datastore.db_pool.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -93,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db.simple_select_one_onecol( + value = yield self.datastore.db_pool.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -107,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db.simple_select_one( + ret = yield self.datastore.db_pool.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -123,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db.simple_select_one( + ret = yield self.datastore.db_pool.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -138,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db.simple_select_list( + ret = yield self.datastore.db_pool.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -151,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( + yield self.datastore.db_pool.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -166,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_update_one( + yield self.datastore.db_pool.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -181,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.db.simple_delete_one( + yield self.datastore.db_pool.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..3fab5a5248 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Make sure we don't clash with in progress updates. self.assertTrue( - self.store.db.updates._all_done, "Background updates are still ongoing" + self.store.db_pool.updates._all_done, "Background updates are still ongoing" ) schema_path = os.path.join( prepare_database.dir_path, - "data_stores", + "databases", "main", "schema", "delta", @@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "test_delete_forward_extremities", run_delta_file ) ) # Ugh, have to reset this flag - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) def test_soft_failed_extremities_handled_correctly(self): diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..224ea6fd79 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -86,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -117,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -204,10 +204,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -225,7 +225,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store.db.simple_update( + self.store.db_pool.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -252,7 +252,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -263,14 +263,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # We should now get the correct result again @@ -293,10 +293,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) user_id = "@user:id" @@ -315,7 +315,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -341,7 +341,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store.db.simple_select_list( + self.store.db_pool.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..daac947cb2 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py
@@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_room_to_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertEquals( @@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_alias_to_room(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertObjectHasAttributes( {"room_id": self.room.to_string(), "servers": ["test"]}, - (yield self.store.get_association_from_room_alias(self.alias)), + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ), ) @defer.inlineCallbacks def test_delete_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) - room_id = yield self.store.delete_room_alias(self.alias) + room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (yield self.store.get_association_from_room_alias(self.alias)) + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ) ) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..9f8d30373b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py
@@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) yield self.store.store_device("user", "device", "display_name") - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) - res = yield self.store.get_e2e_device_keys( - (("user1", "device1"), ("user2", "device2")) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) for i in range(0, 20): - self.get_success(self.store.db.runInteraction("insert", insert_event, i)) + self.get_success( + self.store.db_pool.runInteraction("insert", insert_event, i) + ) # this should get the last ten r = self.get_success(self.store.get_prev_events_for_room(room_id)) @@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for i in range(0, 20): self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room1) + self.store.db_pool.runInteraction("insert", insert_event, i, room1) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room2) + self.store.db_pool.runInteraction("insert", insert_event, i, room2) ) self.get_success( - self.store.db.runInteraction("insert", insert_event, i, room3) + self.store.db_pool.runInteraction("insert", insert_event, i, room3) ) # Test simple case @@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): depth = depth_map[event_id] - self.store.db.simple_insert_txn( + self.store.db_pool.simple_insert_txn( txn, table="events", values={ @@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): }, ) - self.store.db.simple_insert_many_txn( + self.store.db_pool.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): for event_id in auth_graph: next_stream_ordering += 1 self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "insert", insert_event, event_id, next_stream_ordering ) ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..857db071d4 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks @@ -56,7 +60,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.db.runInteraction( + counts = yield self.store.db_pool.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -72,10 +76,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield self.store.add_push_actions_to_staging( - event.event_id, {user_id: action} + yield defer.ensureDeferred( + self.store.add_push_actions_to_staging( + event.event_id, {user_id: action} + ) ) - yield self.store.db.runInteraction( + yield self.store.db_pool.runInteraction( "", self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -83,12 +89,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return self.store.db.runInteraction( + return self.store.db_pool.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.db.runInteraction( + return self.store.db_pool.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -117,7 +123,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.db.simple_delete( + yield self.store.db_pool.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -136,7 +142,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db.simple_insert( + return self.store.db_pool.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..e845410dae 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -14,7 +14,7 @@ # limitations under the License. -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.unittest import HomeserverTestCase @@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db = self.store.db # type: Database + self.db_pool = self.store.db_pool # type: DatabasePool - self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) def _setup_db(self, txn): txn.execute("CREATE SEQUENCE foobar_seq") @@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def _create(conn): return MultiWriterIdGenerator( conn, - self.db, + self.db_pool, instance_name=instance_name, table="foobar", instance_column="instance_name", @@ -55,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): sequence_name="foobar_seq", ) - return self.get_success(self.db.runWithConnection(_create)) + return self.get_success(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): def _insert(txn): @@ -65,7 +65,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): (instance_name,), ) - self.get_success(self.db.runInteraction("test_single_instance", _insert)) + self.get_success(self.db_pool.runInteraction("test_single_instance", _insert)) def test_empty(self): """Test an ID generator against an empty database gives sensible @@ -178,7 +178,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token("master"), 7) - self.get_success(self.db.runInteraction("test", _get_next_txn)) + self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..0155ffd04e 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase): def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") yield self.store.create_profile(self.user.localpart) - yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + yield self.store.set_profile_displayname( + self.user.localpart, self.displayname, 1 + ) users, total = yield self.store.get_users_paginate( 0, 10, name="bc", guests=False diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..e793781a26 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # XXX why are we doing this here? this function is only run at startup # so it is odd to re-run it here. self.get_success( - self.store.db.runInteraction( + self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) ) @@ -204,7 +205,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user, "email", email, now, now) ) - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) @@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -280,7 +287,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ] self.hs.config.mau_limits_reserved_threepids = threepids - d = self.store.db.runInteraction( + d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.get_success(d) @@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..7458a37e54 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -33,9 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) - - yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1) self.assertEquals( "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) @@ -43,10 +41,8 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + self.u_frank.localpart, "http://my.site/here", 1 ) self.assertEquals( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..a6012c973d 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py
@@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_events.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred( + storage.purge_events.purge_history(self.room_id, event, True) + ) self.pump() self.assertEqual(self.successResultOf(purge), None) @@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): ) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) self.pump() f = self.failureResultOf(purge) self.assertIn("greater than forward", f.value.args[0]) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..41511d479f 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): - built_event = yield self._base_builder.build(prev_event_ids) + built_event = yield defer.ensureDeferred( + self._base_builder.build(prev_event_ids) + ) built_event._event_id = self._event_id built_event._dict["event_id"] = self._event_id @@ -341,7 +343,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -359,7 +361,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store.db.simple_select_one_onecol( + self.store.db_pool.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1d77b4a2d6..d07b985a8e 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id=self.u_creator.to_string(), - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id=self.u_creator.to_string(), + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -88,17 +90,21 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.storage.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + yield defer.ensureDeferred( + self.storage.persistence.persist_event( + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f282921538..17c9da4838 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -179,10 +179,10 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Now let's create a room, which will insert a membership @@ -192,7 +192,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store.db.simple_insert( + self.store.db_pool.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -203,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store.db.updates._all_done = False + self.store.db_pool.updates._all_done = False # Now let's actually drive the updates to completion while not self.get_success( - self.store.db.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a0e133cd4a..8bd12fa847 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -68,7 +70,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups_ids( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.storage.state.get_state_for_event(e5.event_id) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event(e5.event_id) + ) self.assertIsNotNone(e4) @@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) self.assertStateMapEqual( @@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, - include_others=True, - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ), + ) ) self.assertStateMapEqual( @@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True + ), + ) ) self.assertStateMapEqual( @@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.storage.state.get_state_groups_ids( - room_id, [e5.event_id] + group_ids = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..ecfafe68a9 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): def test_search_user_dir(self): # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(1, len(r["results"])) self.assertDictEqual( @@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): def test_search_user_dir_all_users(self): self.hs.config.user_directory_search_all_users = True try: - r = yield self.store.search_user_dir(ALICE, "bob", 10) + r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(2, len(r["results"])) self.assertDictEqual( diff --git a/tests/test_federation.py b/tests/test_federation.py
index 87a16d7d7a..c2f12c2741 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): prev_events that said event references. """ - def post_json(destination, path, data, headers=None, timeout=0): + async def post_json(destination, path, data, headers=None, timeout=0): # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} diff --git a/tests/test_server.py b/tests/test_server.py
index 073b2362cc..d628070e48 100644 --- a/tests/test_server.py +++ b/tests/test_server.py
@@ -157,6 +157,29 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + def test_head_request(self): + """ + JsonResource.handler_for_request gives correctly decoded URL args to + the callback, while Twisted will give the raw bytes of URL query + arguments. + """ + + def _callback(request, **kwargs): + return 200, {"result": True} + + res = JsonResource(self.homeserver) + res.register_paths( + "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet", + ) + + # The path was registered as GET, but this is a HEAD request. + request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + self.assertNotIn("body", channel.result) + self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"]) + class OptionsResourceTests(unittest.TestCase): def setUp(self): @@ -255,7 +278,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.reactor = ThreadedMemoryReactorClock() def test_good_response(self): - def callback(request): + async def callback(request): request.write(b"response") request.finish() @@ -275,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): with the right location. """ - def callback(request, **kwargs): + async def callback(request, **kwargs): raise RedirectException(b"/look/an/eagle", 301) res = WrapHtmlRequestHandlerTests.TestResource() @@ -295,7 +318,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): returned too """ - def callback(request, **kwargs): + async def callback(request, **kwargs): e = RedirectException(b"/no/over/there", 304) e.cookies.append(b"session=yespls") raise e @@ -312,3 +335,19 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(location_headers, [b"/no/over/there"]) cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] self.assertEqual(cookies_headers, [b"session=yespls"]) + + def test_head_request(self): + """A head request should work by being turned into a GET request.""" + + async def callback(request): + request.write(b"response") + request.finish() + + res = WrapHtmlRequestHandlerTests.TestResource() + res.callback = callback + + request, channel = make_request(self.reactor, b"HEAD", b"/path") + render(request, res, self.reactor) + + self.assertEqual(channel.result["code"], b"200") + self.assertNotIn("body", channel.result) diff --git a/tests/test_state.py b/tests/test_state.py
index 4858e8fc59..b5c3667d2a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids()) self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) @@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase): self.state.compute_event_context(event, old_state=old_state) ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) @@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase): self.state.compute_event_context(event, old_state=old_state) ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) @@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..d4a722a30f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py
@@ -12,9 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from six import string_types from synapse.api.errors import SynapseError -from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart +from synapse.types import ( + GroupID, + RoomAlias, + UserID, + map_username_to_mxid_localpart, + strip_invalid_mxid_characters, +) from tests import unittest @@ -103,3 +110,16 @@ class MapUsernameTestCase(unittest.TestCase): self.assertEqual( map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast" ) + + +class StripInvalidMxidCharactersTestCase(unittest.TestCase): + def test_return_type(self): + unstripped = strip_invalid_mxid_characters("test") + stripped = strip_invalid_mxid_characters("test@") + + self.assertTrue(isinstance(unstripped, string_types), type(unstripped)) + self.assertTrue(isinstance(stripped, string_types), type(stripped)) + + def test_strip(self): + stripped = strip_invalid_mxid_characters("test@") + self.assertEqual(stripped, "test", stripped) diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index b371efc0df..531a9b9118 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() - yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @defer.inlineCallbacks def test_filtering(self): @@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) for i in range(0, len(events_to_filter)): @@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -265,8 +271,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): storage.main = test_store storage.state = test_store - filtered = yield filter_events_for_server( - test_store, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) ) logger.info("Filtering took %f seconds", time.time() - start) diff --git a/tests/unittest.py b/tests/unittest.py
index 68d2586efd..d0bba3ddef 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: - def get_user_by_access_token(token=None, allow_guest=False): - return succeed( - { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - ) - - def get_user_by_req(request, allow_guest=False, rights="access"): - return succeed( - create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + async def get_user_by_req(request, allow_guest=False, rights="access"): + return create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -422,8 +418,8 @@ class HomeserverTestCase(TestCase): async def run_bg_updates(): with LoggingContext("run_bg_updates", request="run_bg_updates-1"): - while not await stor.db.updates.has_completed_background_updates(): - await stor.db.updates.do_next_background_update(1) + while not await stor.db_pool.updates.has_completed_background_updates(): + await stor.db_pool.updates.do_next_background_update(1) hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -571,7 +567,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore().db.simple_insert( + self.hs.get_datastore().db_pool.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9e348694ad..bc42ffce88 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py
@@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self): """A happy-path case with a new destination and a successful operation""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) # advance the clock a bit before making the request self.pump(1) @@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase): with limiter: pass - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) def test_limiter(self): """General test case which walks through the process of a failing request""" store = self.hs.get_datastore() - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], failure_ts) self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) # now if we try again we should get a failure - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - self.failureResultOf(d, NotRetryingDestination) + self.get_failure( + get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + ) # # advance the clock and try again # self.pump(MIN_RETRY_INTERVAL) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) try: @@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase): except AssertionError: pass - # wait for the update to land - self.pump() - - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], retry_ts) self.assertGreaterEqual( @@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase): # one more go, with success # self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) - d = get_retry_limiter("test_dest", self.clock, store) - self.pump() - limiter = self.successResultOf(d) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) with limiter: @@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase): # wait for the update to land self.pump() - d = store.get_destination_retry_timings("test_dest") - self.pump() - new_timings = self.successResultOf(d) + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) diff --git a/tests/utils.py b/tests/utils.py
index ac643679aa..d543f3ed32 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -154,6 +154,10 @@ def default_config(name, parse=False): "account": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 10000, "burst_count": 10000}, }, + "rc_joins": { + "local": {"per_second": 10000, "burst_count": 10000}, + "remote": {"per_second": 10000, "burst_count": 10000}, + }, "saml2_enabled": False, "public_baseurl": None, "default_identity_server": None, @@ -169,6 +173,8 @@ def default_config(name, parse=False): "update_user_directory": False, "caches": {"global_factor": 1}, "listeners": [{"port": 0, "type": "http"}], + # Enable encryption by default in private rooms + "encryption_enabled_by_default_for_room_type": "invite", } if parse: @@ -638,14 +644,8 @@ class DeferredMockCallable(object): ) -@defer.inlineCallbacks -def create_room(hs, room_id, creator_id): +async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room - - Args: - hs - room_id (str) - creator_id (str) """ persistence_store = hs.get_storage().persistence @@ -653,7 +653,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - yield store.store_room( + await store.store_room( room_id=room_id, room_creator_user_id=creator_id, is_public=False, @@ -671,8 +671,6 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield defer.ensureDeferred( - event_creation_handler.create_new_client_event(builder) - ) + event, context = await event_creation_handler.create_new_client_event(builder) - yield persistence_store.persist_event(event, context) + await persistence_store.persist_event(event, context)