diff options
author | Brendan Abolivier <babolivier@matrix.org> | 2019-05-14 11:43:03 +0100 |
---|---|---|
committer | Brendan Abolivier <babolivier@matrix.org> | 2019-05-14 11:43:03 +0100 |
commit | f608ddbe5c58f35b76c5b1199c60a6c17a684d19 (patch) | |
tree | cfc254400db8c0d3d9cdcef7942e935b09ba4f1d /tests | |
parent | Merge pull request #5115 from matrix-org/babolivier/lookup_path (diff) | |
parent | 0.99.4rc1 (diff) | |
download | synapse-f608ddbe5c58f35b76c5b1199c60a6c17a684d19.tar.xz |
Merge branch 'release-v0.99.4' into dinsic dinsic_2019-05-14
Diffstat (limited to 'tests')
66 files changed, 1320 insertions, 1161 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2a7044801a..6ba623de13 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -109,7 +109,6 @@ class FilteringTestCase(unittest.TestCase): "event_format": "client", "event_fields": ["type", "content", "sender"], }, - # a single backslash should be permitted (though it is debatable whether # it should be permitted before anything other than `.`, and what that # actually means) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 30a255d441..dbdd427cac 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -10,19 +10,19 @@ class TestRatelimiter(unittest.TestCase): key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) - self.assertEquals(10., time_allowed) + self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 ) self.assertFalse(allowed) - self.assertEquals(10., time_allowed) + self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) - self.assertEquals(20., time_allowed) + self.assertEquals(20.0, time_allowed) def test_pruning(self): limiter = Ratelimiter() diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 590abc1e92..48792d1480 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -25,16 +25,18 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=FederationReaderServer, + http_client=None, homeserverToUse=FederationReaderServer ) return hs - @parameterized.expand([ - (["federation"], "auth_fail"), - ([], "no_resource"), - (["openid", "federation"], "auth_fail"), - (["openid"], "auth_fail"), - ]) + @parameterized.expand( + [ + (["federation"], "auth_fail"), + ([], "no_resource"), + (["openid", "federation"], "auth_fail"), + (["openid"], "auth_fail"), + ] + ) def test_openid_listener(self, names, expectation): """ Test different openid listener configurations. @@ -53,17 +55,14 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = ( - site.resource.children[b"_matrix"].children[b"federation"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise request, channel = self.make_request( - "GET", - "/_matrix/federation/v1/openid/userinfo", + "GET", "/_matrix/federation/v1/openid/userinfo" ) self.render(request) @@ -74,16 +73,18 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=SynapseHomeServer, + http_client=None, homeserverToUse=SynapseHomeServer ) return hs - @parameterized.expand([ - (["federation"], "auth_fail"), - ([], "no_resource"), - (["openid", "federation"], "auth_fail"), - (["openid"], "auth_fail"), - ]) + @parameterized.expand( + [ + (["federation"], "auth_fail"), + ([], "no_resource"), + (["openid", "federation"], "auth_fail"), + (["openid"], "auth_fail"), + ] + ) def test_openid_listener(self, names, expectation): """ Test different openid listener configurations. @@ -102,17 +103,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = ( - site.resource.children[b"_matrix"].children[b"federation"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise request, channel = self.make_request( - "GET", - "/_matrix/federation/v1/openid/userinfo", + "GET", "/_matrix/federation/v1/openid/userinfo" ) self.render(request) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 795b4c298d..5017cbce85 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -45,13 +45,7 @@ class ConfigGenerationTestCase(unittest.TestCase): ) self.assertSetEqual( - set( - [ - "homeserver.yaml", - "lemurs.win.log.config", - "lemurs.win.signing.key", - ] - ), + set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), set(os.listdir(self.dir)), ) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index 47fffcfeb2..0ec10019b3 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -22,7 +22,8 @@ from tests import unittest class RoomDirectoryConfigTestCase(unittest.TestCase): def test_alias_creation_acl(self): - config = yaml.safe_load(""" + config = yaml.safe_load( + """ alias_creation_rules: - user_id: "*bob*" alias: "*" @@ -38,43 +39,49 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): action: "allow" room_list_publication_rules: [] - """) + """ + ) rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.assertFalse(rd_config.is_alias_creation_allowed( - user_id="@bob:example.com", - room_id="!test", - alias="#test:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@test:example.com", - room_id="!test", - alias="#unofficial_st:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@foobar:example.com", - room_id="!test", - alias="#test:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@gah:example.com", - room_id="!test", - alias="#goo:example.com", - )) - - self.assertFalse(rd_config.is_alias_creation_allowed( - user_id="@test:example.com", - room_id="!test", - alias="#test:example.com", - )) + self.assertFalse( + rd_config.is_alias_creation_allowed( + user_id="@bob:example.com", room_id="!test", alias="#test:example.com" + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + room_id="!test", + alias="#unofficial_st:example.com", + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@foobar:example.com", + room_id="!test", + alias="#test:example.com", + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@gah:example.com", room_id="!test", alias="#goo:example.com" + ) + ) + + self.assertFalse( + rd_config.is_alias_creation_allowed( + user_id="@test:example.com", room_id="!test", alias="#test:example.com" + ) + ) def test_room_publish_acl(self): - config = yaml.safe_load(""" + config = yaml.safe_load( + """ alias_creation_rules: [] room_list_publication_rules: @@ -92,55 +99,66 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): action: "allow" - room_id: "!test-deny" action: "deny" - """) + """ + ) rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@bob:example.com", - room_id="!test", - aliases=["#test:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#unofficial_st:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@foobar:example.com", - room_id="!test", - aliases=[], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@gah:example.com", - room_id="!test", - aliases=["#goo:example.com"], - )) - - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#test:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@foobar:example.com", - room_id="!test-deny", - aliases=[], - )) - - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@gah:example.com", - room_id="!test-deny", - aliases=[], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#unofficial_st:example.com", "#blah:example.com"], - )) + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@bob:example.com", + room_id="!test", + aliases=["#test:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#unofficial_st:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@foobar:example.com", room_id="!test", aliases=[] + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@gah:example.com", + room_id="!test", + aliases=["#goo:example.com"], + ) + ) + + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#test:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@foobar:example.com", room_id="!test-deny", aliases=[] + ) + ) + + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@gah:example.com", room_id="!test-deny", aliases=[] + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#unofficial_st:example.com", "#blah:example.com"], + ) + ) diff --git a/tests/config/test_server.py b/tests/config/test_server.py index f5836d73ac..de64965a60 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -19,7 +19,6 @@ from tests import unittest class ServerConfigTestCase(unittest.TestCase): - def test_is_threepid_reserved(self): user1 = {'medium': 'email', 'address': 'user1@example.com'} user2 = {'medium': 'email', 'address': 'user2@example.com'} diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index c260d3359f..40ca428778 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -26,7 +26,6 @@ class TestConfig(TlsConfig): class TLSConfigTests(TestCase): - def test_warn_self_signed(self): """ Synapse will give a warning when it loads a self-signed certificate. @@ -34,7 +33,8 @@ class TLSConfigTests(TestCase): config_dir = self.mktemp() os.mkdir(config_dir) with open(os.path.join(config_dir, "cert.pem"), 'w') as f: - f.write("""-----BEGIN CERTIFICATE----- + f.write( + """-----BEGIN CERTIFICATE----- MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb @@ -56,11 +56,12 @@ I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2 SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= ------END CERTIFICATE-----""") +-----END CERTIFICATE-----""" + ) config = { "tls_certificate_path": os.path.join(config_dir, "cert.pem"), - "tls_fingerprints": [] + "tls_fingerprints": [], } t = TestConfig() @@ -75,5 +76,5 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= "Self-signed TLS certificates will not be accepted by " "Synapse 1.0. Please either provide a valid certificate, " "or use Synapse's ACME support to provision one." - ) + ), ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f5bd7a1aa1..3c79d4afe7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -169,7 +169,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1, )] + [("server10", json1)] ) res_deferreds_2[0].addBoth(self.check_context, None) yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) @@ -345,6 +345,7 @@ def _verify_json_for_server(keyring, server_name, json_object): """thin wrapper around verify_json_for_server which makes sure it is wrapped with the patched defer.inlineCallbacks. """ + @defer.inlineCallbacks def v(): rv1 = yield keyring.verify_json_for_server(server_name, json_object) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 28e7e27416..7bb106b5f7 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -33,11 +33,15 @@ class FederationSenderTestCases(HomeserverTestCase): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -46,21 +50,24 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but @@ -68,11 +75,15 @@ class FederationSenderTestCases(HomeserverTestCase): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -81,25 +92,30 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) mock_send_transaction.reset_mock() # send the second RR - receipt = ReadReceipt("room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() mock_send_transaction.assert_not_called() @@ -111,18 +127,21 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['other_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['other_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5b2105bc76..917548bb31 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -115,11 +115,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): # We cheekily override the config to add custom alias creation rules config = {} config["alias_creation_rules"] = [ - { - "user_id": "*", - "alias": "#unofficial_*", - "action": "allow", - } + {"user_id": "*", "alias": "#unofficial_*", "action": "allow"} ] config["room_list_publication_rules"] = [] @@ -162,9 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -179,10 +173,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = True # Room list is enabled so we should get some results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -191,10 +182,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = False # Room list disabled so we should get no results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) @@ -202,9 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 1c49bbbc3c..2e72a1dd23 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -36,7 +36,7 @@ room_keys = { "first_message_index": 1, "forwarded_count": 1, "is_verified": False, - "session_data": "SSBBTSBBIEZJU0gK" + "session_data": "SSBBTSBBIEZJU0gK", } } } @@ -47,15 +47,13 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer + self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, - handlers=None, - replication_layer=mock.Mock(), + self.addCleanup, handlers=None, replication_layer=mock.Mock() ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) self.local_user = "@boris:" + self.hs.hostname @@ -88,67 +86,86 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # upload a new one... - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(res, "2") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "2", - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "2", + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) @defer.inlineCallbacks def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - res = yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + res = yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + self.assertDictEqual( + res, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) @defer.inlineCallbacks def test_update_missing_version(self): @@ -156,11 +173,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version(self.local_user, "1", { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1" - }) + yield self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -170,29 +191,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 400 if the version in the body is missing or doesn't match """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) @@ -223,10 +252,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can delete it @@ -255,16 +284,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = yield self.handler.get_room_keys(self.local_user, version) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest @@ -275,7 +302,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys) + yield self.handler.upload_room_keys( + self.local_user, "no_version", room_keys + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -285,10 +314,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None @@ -304,16 +333,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(version, "2") res = None @@ -327,10 +359,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -340,18 +372,13 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check getting room_keys for a given room res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org" + self.local_user, version, room_id="!abc:matrix.org" ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) self.assertDictEqual(res, room_keys) @@ -359,10 +386,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -378,7 +405,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "SSBBTSBBIEZJU0gK" + "SSBBTSBBIEZJU0gK", ) # test that marking the session as verified however /does/ replace it @@ -387,8 +414,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # test that a session with a higher forwarded_count doesn't replace one @@ -399,8 +425,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # TODO: check edge cases as well as the common variations here @@ -409,56 +434,36 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") # check for bulk-delete yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys(self.local_user, version) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", + self.local_user, version, room_id="!abc:matrix.org" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 94c6080e34..f70c6e7d65 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -424,8 +424,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "server", http_client=None, - federation_sender=Mock(), + "server", http_client=None, federation_sender=Mock() ) return hs @@ -457,7 +456,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test2 as online, test will be offline with a last_active of 0 self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -506,13 +505,13 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test as online self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) # Add servers to the room @@ -541,8 +540,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), - states=[expected_state] + destinations=set(("server2", "server3")), states=[expected_state] ) def _add_new_user(self, room_id, user_id): @@ -565,7 +563,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): type=EventTypes.Member, sender=user_id, state_key=user_id, - content={"membership": Membership.JOIN} + content={"membership": Membership.JOIN}, ) prev_event_ids = self.get_success( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5a0b6c201c..cb8b4d2913 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,20 +64,22 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) hs = self.setup_test_homeserver( - datastore=(Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_devices_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - )), + datastore=( + Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + ] + ) + ), notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring, @@ -87,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): # the tests assume that we are starting at unix time 1000 - reactor.pump((1000, )) + reactor.pump((1000,)) mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -114,6 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + hs.get_auth().check_joined_room = check_joined_room def get_joined_hosts_for_room(room_id): @@ -123,6 +126,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def get_current_users_in_room(room_id): return set(str(u) for u in self.room_members) + hs.get_state_handler().get_current_users_in_room = get_current_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( @@ -141,21 +145,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -170,12 +169,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) + ) put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( @@ -216,14 +214,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -247,14 +241,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.stopped_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( "farm", @@ -274,18 +268,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) def test_typing_timeout(self): @@ -293,22 +279,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -320,45 +301,30 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) - self.reactor.pump([16, ]) + self.reactor.pump([16]) - self.on_new_event.assert_has_calls( - [call('typing_key', 2, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=1 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) # SYN-230 - see if we can still set after timeout - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 3, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index f1d0aa42b6..44468f5382 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -14,8 +14,9 @@ # limitations under the License. from mock import Mock +import synapse.rest.admin from synapse.api.constants import UserTypes -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import user_directory from synapse.storage.roommember import ProfileInfo @@ -29,7 +30,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, ] @@ -327,7 +328,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): user_directory.register_servlets, room.register_servlets, login.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, ] def make_homeserver(self, reactor, clock): @@ -351,9 +352,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Assert user directory is not empty request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -362,9 +361,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.user_directory_search_enabled = False request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index ee8010f598..851fc0eb33 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -24,14 +24,12 @@ def get_test_cert_file(): # # openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \ # -nodes -subj '/CN=testserv' - return os.path.join( - os.path.dirname(__file__), - 'server.pem', - ) + return os.path.join(os.path.dirname(__file__), 'server.pem') class ServerTLSContext(object): """A TLS Context which presents our test cert.""" + def __init__(self): self.filename = get_test_cert_file() diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index dcf184d3cf..7036615041 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -39,6 +39,7 @@ from synapse.util.logcontext import LoggingContext from tests.http import ServerTLSContext from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase +from tests.utils import default_config logger = logging.getLogger(__name__) @@ -53,7 +54,7 @@ class MatrixFederationAgentTests(TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory(None), + tls_client_options_factory=ClientTLSOptionsFactory(default_config("test")), _well_known_tls_policy=TrustingTLSPolicyForHTTPS(), _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, @@ -78,12 +79,12 @@ class MatrixFederationAgentTests(TestCase): # stubbing that out here. client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol), + FakeTransport(server_tls_protocol, self.reactor, client_protocol) ) # tell the server tls protocol to send its stuff back to the client, too server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol), + FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) # give the reactor a pump to get the TLS juices flowing. @@ -124,7 +125,7 @@ class MatrixFederationAgentTests(TestCase): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers={}, + self, client_factory, expected_sni, content, response_headers={} ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -138,8 +139,7 @@ class MatrixFederationAgentTests(TestCase): """ # make the connection for .well-known well_known_server = self._make_connection( - client_factory, - expected_sni=expected_sni, + client_factory, expected_sni=expected_sni ) # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) @@ -153,17 +153,14 @@ class MatrixFederationAgentTests(TestCase): """ self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # send back a response for k, v in headers.items(): request.setHeader(k, v) request.write(content) request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) def test_get(self): """ @@ -183,18 +180,14 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b"testserv", - ) + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] ) content = request.content.read() self.assertEqual(content, b'') @@ -243,19 +236,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'1.2.3.4'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) # finish the request request.finish() @@ -284,19 +271,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) # finish the request request.finish() @@ -325,19 +306,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 80) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]:80'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) # finish the request request.finish() @@ -376,7 +351,7 @@ class MatrixFederationAgentTests(TestCase): # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -386,19 +361,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -426,13 +395,14 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -443,8 +413,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -452,8 +421,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -489,8 +457,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) redirect_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) # send a 302 redirect @@ -499,7 +466,7 @@ class MatrixFederationAgentTests(TestCase): request.redirect(b'https://testserv/even_better_known') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # now there should be another connection clients = self.reactor.tcpClients @@ -509,8 +476,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) well_known_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) self.assertEqual(len(well_known_server.requests), 1, "No request after 302") @@ -520,11 +486,11 @@ class MatrixFederationAgentTests(TestCase): request.write(b'{ "m.server": "target-server" }') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -535,8 +501,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -544,8 +509,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -584,12 +548,12 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON', + client_factory, expected_sni=b"testserv", content=b'NOT JSON' ) # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -599,19 +563,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -634,7 +592,7 @@ class MatrixFederationAgentTests(TestCase): # the request for a .well-known will have failed with a DNS lookup error. self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # Make sure treq is trying to connect @@ -645,19 +603,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -684,17 +636,18 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443), + Server(host=b"srvtarget", port=8443) ] self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target of the SRV record @@ -705,8 +658,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -714,8 +666,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -756,7 +707,7 @@ class MatrixFederationAgentTests(TestCase): # now there should have been a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # We should fall back to port 8448 @@ -768,8 +719,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -777,8 +727,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -800,7 +749,7 @@ class MatrixFederationAgentTests(TestCase): self.assertNoResult(test_d) self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # Make sure treq is trying to connect @@ -812,8 +761,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -821,8 +769,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -896,67 +843,70 @@ class TestCachePeriodFromHeaders(TestCase): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}), - ), 100, + Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + ), + 100, ) # missing value - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=, bar']}), - )) + self.assertIsNone( + _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + ) # hackernews: bogus due to semicolon - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}), - )) + self.assertIsNone( + _cache_period_from_headers( + Headers({b'Cache-Control': [b'private; max-age=0']}) + ) + ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}), - ), 0, + Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + ), + 0, ) # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}), - ), 0, + Headers({b'cache-control': [b'private, max-age=0']}) + ), + 0, ) def test_expires(self): self.assertEqual( _cache_period_from_headers( Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), - time_now=lambda: 1548833700 - ), 33, + time_now=lambda: 1548833700, + ), + 33, ) # cache-control overrides expires self.assertEqual( _cache_period_from_headers( - Headers({ - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'] - }), - time_now=lambda: 1548833700 - ), 10, + Headers( + { + b'cache-control': [b'max-age=10'], + b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + } + ), + time_now=lambda: 1548833700, + ), + 10, ) # invalid expires means immediate expiry - self.assertEqual( - _cache_period_from_headers( - Headers({b'Expires': [b'0']}), - ), 0, - ) + self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) def _check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) def _build_test_server(): @@ -972,7 +922,7 @@ def _build_test_server(): server_factory.log = _log_request server_tls_factory = TLSMemoryBIOFactory( - ServerTLSContext(), isClient=False, wrappedFactory=server_factory, + ServerTLSContext(), isClient=False, wrappedFactory=server_factory ) return server_tls_factory.buildProtocol(None) @@ -986,6 +936,7 @@ def _log_request(request): @implementer(IPolicyForHTTPS) class TrustingTLSPolicyForHTTPS(object): """An IPolicyForHTTPS which doesn't do any certificate verification""" + def creatorForNetloc(self, hostname, port): certificateOptions = OpenSSLCertificateOptions() return ClientTLSOptions(hostname, certificateOptions.getContext()) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index a872e2441e..034c0db8d2 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -68,9 +68,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) - result_deferred.callback( - ([answer_srv], None, None) - ) + result_deferred.callback(([answer_srv], None, None)) servers = self.successResultOf(test_d) @@ -112,7 +110,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver( - dns_client=dns_client_mock, cache=cache, get_time=clock.time, + dns_client=dns_client_mock, cache=cache, get_time=clock.time ) servers = yield resolver.resolve_service(service_name) @@ -168,11 +166,13 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # returning a single "." should make the lookup fail with a ConenctError - lookup_deferred.callback(( - [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], - None, - None, - )) + lookup_deferred.callback( + ( + [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], + None, + None, + ) + ) self.failureResultOf(resolve_d, ConnectError) @@ -191,14 +191,16 @@ class SrvResolverTestCase(unittest.TestCase): resolve_d = resolver.resolve_service(service_name) self.assertNoResult(resolve_d) - lookup_deferred.callback(( - [ - dns.RRHeader(type=dns.A, payload=dns.Record_A()), - dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), - ], - None, - None, - )) + lookup_deferred.callback( + ( + [ + dns.RRHeader(type=dns.A, payload=dns.Record_A()), + dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), + ], + None, + None, + ) + ) servers = self.successResultOf(resolve_d) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index cd8e086f86..279e456614 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -36,9 +36,7 @@ from tests.unittest import HomeserverTestCase def check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) class FederationClientTests(HomeserverTestCase): @@ -54,6 +52,7 @@ class FederationClientTests(HomeserverTestCase): """ happy-path test of a GET request """ + @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: @@ -175,8 +174,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance( - f.value.inner_exception, - (ConnectingCancelledError, TimeoutError), + f.value.inner_exception, (ConnectingCancelledError, TimeoutError) ) def test_client_connect_no_response(self): @@ -216,9 +214,7 @@ class FederationClientTests(HomeserverTestCase): Once the client gets the headers, _request returns successfully. """ request = MatrixFederationRequest( - method="GET", - destination="testserv:8008", - path="foo/bar", + method="GET", destination="testserv:8008", path="foo/bar" ) d = self.cl._send_request(request, timeout=10000) @@ -258,8 +254,10 @@ class FederationClientTests(HomeserverTestCase): # Send it the HTTP response client.dataReceived( - (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n") + ( + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" + ) ) # Push by enough to time it out @@ -274,9 +272,7 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -329,9 +325,7 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -368,10 +362,7 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json( - "testserv:8008", "foo/bar", timeout=10000, - data={"a": "b"} - ) + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) self.pump() diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py index 0f613945c8..ee0add3455 100644 --- a/tests/patch_inline_callbacks.py +++ b/tests/patch_inline_callbacks.py @@ -45,7 +45,9 @@ def do_patch(): except Exception: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s on exception" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -54,7 +56,9 @@ def do_patch(): if not isinstance(res, Deferred) or res.called: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -66,9 +70,7 @@ def do_patch(): err = ( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % ( - f, LoggingContext.current_context(), start_context, - ) + ) % (f, LoggingContext.current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) @@ -76,7 +78,9 @@ def do_patch(): if LoggingContext.current_context() != start_context: err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", - f, start_context, LoggingContext.current_context(), + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index be3fed8de3..325ea449ae 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -19,7 +19,8 @@ import pkg_resources from twisted.internet.defer import Deferred -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from tests.unittest import HomeserverTestCase @@ -33,7 +34,7 @@ class EmailPusherTests(HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 6dc45e8506..13bd2c8688 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -17,7 +17,8 @@ from mock import Mock from twisted.internet.defer import Deferred -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from synapse.util.logcontext import make_deferred_yieldable from tests.unittest import HomeserverTestCase @@ -32,7 +33,7 @@ class HTTPPusherTests(HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 1f72a2a04f..104349cdbd 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -74,21 +74,18 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( master_result, expected_result, - "Expected master result to be %r but was %r" % ( - expected_result, master_result - ), + "Expected master result to be %r but was %r" + % (expected_result, master_result), ) self.assertEqual( slaved_result, expected_result, - "Expected slave result to be %r but was %r" % ( - expected_result, slaved_result - ), + "Expected slave result to be %r but was %r" + % (expected_result, slaved_result), ) self.assertEqual( master_result, slaved_result, - "Slave result %r does not match master result %r" % ( - slaved_result, master_result - ), + "Slave result %r does not match master result %r" + % (slaved_result, master_result), ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 65ecff3bd6..a368117b43 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,10 +234,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([ - (j2, j2ctx), - (msg, msgctx), - ])) + self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) self.replicate() event_source = RoomEventSource(self.hs) @@ -257,15 +254,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # # First, we get a list of the rooms we are joined to joined_rooms = self.get_success( - self.slaved_store.get_rooms_for_user_with_stream_ordering( - USER_ID_2, - ), + self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2) ) # Then, we get a list of the events since the last sync membership_changes = self.get_success( self.slaved_store.get_membership_changes_for_user( - USER_ID_2, prev_token, current_token, + USER_ID_2, prev_token, current_token ) ) @@ -298,9 +293,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.master_store.persist_events([(event, context)], backfilled=True) ) else: - self.get_success( - self.master_store.persist_event(event, context) - ) + self.get_success(self.master_store.persist_event(event, context)) return event @@ -359,9 +352,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context( - event - )) + context = self.get_success(state_handler.compute_event_context(event)) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 38b368a972..ce3835ae6a 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -22,6 +22,7 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) @@ -52,6 +53,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" + def __init__(self): self.received_rdata_rows = [] @@ -69,6 +71,4 @@ class TestReplicationClientHandler(object): def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append( - (stream_name, token, r) - ) + self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/rest/admin/__init__.py b/tests/rest/admin/__init__.py new file mode 100644 index 0000000000..1453d04571 --- /dev/null +++ b/tests/rest/admin/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/admin/test_admin.py index c00ef21d75..ee5f09041f 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -19,50 +19,37 @@ import json from mock import Mock +import synapse.rest.admin from synapse.api.constants import UserTypes -from synapse.rest.client.v1 import admin, events, login, room +from synapse.http.server import JsonResource +from synapse.rest.admin import VersionServlet +from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v2_alpha import groups from tests import unittest class VersionTestCase(unittest.HomeserverTestCase): + url = '/_synapse/admin/v1/server_version' - servlets = [ - admin.register_servlets, - login.register_servlets, - ] - - url = '/_matrix/client/r0/admin/server_version' + def create_test_json_resource(self): + resource = JsonResource(self.hs) + VersionServlet(self.hs).register(resource) + return resource def test_version_string(self): - self.register_user("admin", "pass", admin=True) - self.admin_token = self.login("admin", "pass") - - request, channel = self.make_request("GET", self.url, - access_token=self.admin_token) + request, channel = self.make_request("GET", self.url, shorthand=False) self.render(request) - self.assertEqual(200, int(channel.result["code"]), - msg=channel.result["body"]) - self.assertEqual({'server_version', 'python_version'}, - set(channel.json_body.keys())) - - def test_inaccessible_to_non_admins(self): - self.register_user("unprivileged-user", "pass", admin=False) - user_token = self.login("unprivileged-user", "pass") - - request, channel = self.make_request("GET", self.url, - access_token=user_token) - self.render(request) - - self.assertEqual(403, int(channel.result['code']), - msg=channel.result['body']) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + {'server_version', 'python_version'}, set(channel.json_body.keys()) + ) class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [admin.register_servlets] + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor, clock): @@ -213,9 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update( - nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin" - ) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() body = json.dumps( @@ -343,11 +328,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Invalid user_type - body = json.dumps({ - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid"} + body = json.dumps( + { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } ) request, channel = self.make_request("POST", self.url, body.encode('utf8')) self.render(request) @@ -358,7 +345,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): class ShutdownRoomTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, events.register_servlets, room.register_servlets, @@ -370,9 +357,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): hs.config.user_consent_version = "1" consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = ( - "http://example.com" - ) + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder self.store = hs.get_datastore() @@ -384,9 +369,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self.other_user_token = self.login("user", "pass") # Mark the admin user as having consented - self.get_success( - self.store.user_set_consent_version(self.admin_user, "1"), - ) + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not @@ -398,9 +381,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) # Assert one user in room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([self.other_user], users_in_room) # Enable require consent to send events @@ -408,8 +389,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - room_id, - body="foo", tok=self.other_user_token, expect_code=403, + room_id, body="foo", tok=self.other_user_token, expect_code=403 ) # Test that the admin can still send shutdown @@ -425,9 +405,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Assert there is now no longer anyone in the room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) @unittest.DEBUG @@ -472,30 +450,26 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "rooms/%s/initialSync" % (room_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) class DeleteGroupTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, groups.register_servlets, ] @@ -515,15 +489,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "POST", "/create_group".encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] @@ -533,27 +503,17 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.admin_user_tok, - content={} + "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.other_user_token, - content={} + "PUT", url.encode('ascii'), access_token=self.other_user_token, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -565,15 +525,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "POST", url.encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check group returns 404 self._check_group(group_id, expect_code=404) @@ -589,28 +545,22 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/profile" % (group_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ request, channel = self.make_request( - "GET", - "/joined_groups".encode('ascii'), - access_token=access_token, + "GET", "/joined_groups".encode('ascii'), access_token=access_token ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 4294bbec2a..5528971190 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -15,8 +15,9 @@ import os +import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.consent import consent_resource from tests import unittest @@ -31,7 +32,7 @@ except Exception: class ConsentResourceTestCase(unittest.HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ec260a2873..36bb19a24c 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -19,7 +19,8 @@ from mock import Mock from twisted.internet import defer -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account from tests import unittest @@ -30,7 +31,7 @@ class IdentityDisabledTestCase(unittest.HomeserverTestCase): servlets = [ account.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -109,7 +110,7 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): servlets = [ account.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -154,9 +155,7 @@ class IdentityEnabledTestCase(unittest.HomeserverTestCase): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ( - "/rooms/%s/invite" % (room_id) - ).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') request, channel = self.make_request( b"POST", request_url, request_data, access_token=self.tok, ) diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py new file mode 100644 index 0000000000..73c5b44b46 --- /dev/null +++ b/tests/rest/client/v1/test_directory.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.rest import admin +from synapse.rest.client.v1 import directory, login, room +from synapse.types import RoomAlias +from synapse.util.stringutils import random_string + +from tests import unittest + + +class DirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config.require_membership_for_aliases = True + + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.user = self.register_user("user", "test") + self.user_tok = self.login("user", "test") + + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) + + def test_directory_endpoint_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_directory(403) + + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + + def test_directory_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(400, alias_length=256) + + def test_state_event_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + + def test_directory_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(200) + + def test_room_creation_too_long(self): + url = "/_matrix/client/r0/createRoom" + + # We use deliberately a localpart under the length threshold so + # that we can make sure that the check is done on the whole alias. + data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, 400, channel.result) + + def test_room_creation(self): + url = "/_matrix/client/r0/createRoom" + + # Check with an alias of allowed length. There should already be + # a test that ensures it works in test_register.py, but let's be + # as cautious as possible here. + data = {"room_alias_name": random_string(5)} + request_data = json.dumps(data) + request, channel = self.make_request( + "POST", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + def set_alias_via_state_event(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def set_alias_via_directory(self, expected_code, alias_length=5): + url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.user_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def random_alias(self, length): + return RoomAlias(random_string(length), self.hs.hostname).to_string() + + def ensure_user_left_room(self): + self.ensure_membership("leave") + + def ensure_user_joined_room(self): + self.ensure_membership("join") + + def ensure_membership(self, membership): + try: + if membership == "leave": + self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) + if membership == "join": + self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 36d8547275..8a9a55a527 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -17,7 +17,8 @@ from mock import Mock, NonCallableMock -from synapse.rest.client.v1 import admin, events, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import events, login, room from tests import unittest @@ -28,7 +29,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): servlets = [ events.register_servlets, room.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 86312f1096..0397f91a9e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,6 +1,7 @@ import json -from synapse.rest.client.v1 import admin, login +import synapse.rest.admin +from synapse.rest.client.v1 import login from tests import unittest @@ -10,7 +11,7 @@ LOGIN_URL = b"/_matrix/client/r0/login" class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] @@ -36,10 +37,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -56,14 +54,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -81,10 +76,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -101,14 +93,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -126,10 +115,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) @@ -146,14 +132,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 1eab9c3bdb..ed034879cf 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -20,7 +20,8 @@ from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError, SynapseError -from synapse.rest.client.v1 import profile +from synapse.rest import admin +from synapse.rest.client.v1 import login, profile, room from tests import unittest @@ -42,6 +43,7 @@ class ProfileTestCase(unittest.TestCase): "set_displayname", "get_avatar_url", "set_avatar_url", + "check_profile_query_allowed", ] ) @@ -155,3 +157,77 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") + + +class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config.require_auth_for_profile_requests = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User owning the requested profile. + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.profile_url = "/profile/%s" % (self.owner) + + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) + + def test_no_auth(self): + self.try_fetch_profile(401) + + def test_not_in_shared_room(self): + self.ensure_requester_left_room() + + self.try_fetch_profile(403, access_token=self.requester_tok) + + def test_in_shared_room(self): + self.ensure_requester_left_room() + + self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) + + self.try_fetch_profile(200, self.requester_tok) + + def try_fetch_profile(self, expected_code, access_token=None): + self.request_profile(expected_code, access_token=access_token) + + self.request_profile( + expected_code, url_suffix="/displayname", access_token=access_token + ) + + self.request_profile( + expected_code, url_suffix="/avatar_url", access_token=access_token + ) + + def request_profile(self, expected_code, url_suffix="", access_token=None): + request, channel = self.make_request( + "GET", self.profile_url + url_suffix, access_token=access_token + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def ensure_requester_left_room(self): + try: + self.helper.leave( + room=self.room_id, user=self.requester, tok=self.requester_tok + ) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. + pass diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 015c144248..9b191436cc 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -22,8 +22,9 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer +import synapse.rest.admin from synapse.api.constants import Membership -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, room from tests import unittest @@ -803,7 +804,7 @@ class RoomMessageListTestCase(RoomBase): class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] @@ -903,3 +904,35 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): self.assertEqual( context["profile_info"][self.other_user_id]["displayname"], "otheruser" ) + + +class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/publicRooms" + + config = self.default_config() + config.restrict_public_rooms_to_local_users = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_restricted_no_auth(self): + request, channel = self.make_request("GET", self.url) + self.render(request) + self.assertEqual(channel.code, 401, channel.result) + + def test_restricted_auth(self): + self.register_user("user", "pass") + tok = self.login("user", "pass") + + request, channel = self.make_request("GET", self.url, access_token=tok) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 7fa120a10f..0ca3c4657b 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -16,8 +16,8 @@ from twisted.internet.defer import succeed +import synapse.rest.admin from synapse.api.constants import LoginType -from synapse.rest.client.v1 import admin from synapse.rest.client.v2_alpha import auth, register from tests import unittest @@ -27,7 +27,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index bbfc77e829..f3ef977404 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import synapse.rest.admin from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS -from synapse.rest.client.v1 import admin, login +from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import capabilities from tests import unittest @@ -23,7 +23,7 @@ from tests import unittest class CapabilitiesTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, login.register_servlets, ] diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index d3611ed21f..be95dc592d 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,14 +1,23 @@ import datetime import json +import os +import pkg_resources + +import synapse.rest.admin from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import admin, login -from synapse.rest.client.v2_alpha import register, sync +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import account_validity, register, sync from tests import unittest +try: + from synapse.push.mailer import load_jinja2_templates +except ImportError: + load_jinja2_templates = None + class RegisterRestServletTestCase(unittest.HomeserverTestCase): @@ -32,11 +41,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): as_token = "i_am_an_app_service" appservice = ApplicationService( - as_token, self.hs.config.server_name, + as_token, + self.hs.config.server_name, id="1234", - namespaces={ - "users": [{"regex": r"@as_user.*", "exclusive": True}], - }, + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, ) self.hs.get_datastore().services_cache.append(appservice) @@ -48,10 +56,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) - det_data = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_appservice_registration_invalid(self): @@ -119,10 +124,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) - det_data = { - "home_server": self.hs.hostname, - "device_id": "guest_device", - } + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @@ -150,7 +152,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -178,7 +180,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -190,13 +192,15 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): servlets = [ register.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, + account_validity.register_servlets, ] def make_homeserver(self, reactor, clock): config = self.default_config() + # Test for account expiring after a week. config.enable_registration = True config.account_validity.enabled = True config.account_validity.period = 604800000 # Time in ms for 1 week @@ -210,21 +214,187 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. + request, channel = self.make_request(b"GET", "/sync", access_token=tok) + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + request, channel = self.make_request(b"GET", "/sync", access_token=tok) + self.render(request) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_manual_renewal(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + # If we register the admin user at the beginning of the test, it will + # expire at the same time as the normal user and the renewal request + # will be denied. + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = {"user_id": user_id} + request_data = json.dumps(params) request, channel = self.make_request( - b"GET", "/sync", access_token=tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + request, channel = self.make_request(b"GET", "/sync", access_token=tok) + self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + def test_manual_expire(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_matrix/client/unstable/admin/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) request, channel = self.make_request( - b"GET", "/sync", access_token=tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + request, channel = self.make_request(b"GET", "/sync", access_token=tok) + self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) + + +class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): + + skip = "No Jinja installed" if not load_jinja2_templates else None + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # Test for account expiring after a week and renewal emails being sent 2 + # days before expiry. + config.enable_registration = True + config.account_validity.enabled = True + config.account_validity.renew_by_email_enabled = True + config.account_validity.period = 604800000 # Time in ms for 1 week + config.account_validity.renew_at = 172800000 # Time in ms for 2 days + config.account_validity.renew_email_subject = "Renew your account" + + # Email config. + self.email_attempts = [] + + def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + return + + config.email_template_dir = os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ) + config.email_expiry_template_html = "notice_expiry.html" + config.email_expiry_template_text = "notice_expiry.txt" + config.email_smtp_host = "127.0.0.1" + config.email_smtp_port = 20 + config.require_transport_security = False + config.email_smtp_user = None + config.email_smtp_pass = None + config.email_notif_from = "test@example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + + self.store = self.hs.get_datastore() + + return self.hs + + def test_renewal_email(self): + self.email_attempts = [] + + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.clock.time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + + # Move 6 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=6).total_seconds()) + self.assertEqual(len(self.email_attempts), 1) + + # Retrieving the URL from the email is too much pain for now, so we + # retrieve the token from the DB. + renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + request, channel = self.make_request(b"GET", url) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Move 3 days forward. If the renewal failed, every authed request with + # our access token should be denied from now, otherwise they should + # succeed. + self.reactor.advance(datetime.timedelta(days=3).total_seconds()) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_manual_email_send(self): + self.email_attempts = [] + + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.clock.time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + + request, channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 99b716f00a..71895094bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -15,7 +15,8 @@ from mock import Mock -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync from tests import unittest @@ -72,7 +73,7 @@ class FilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, sync.register_servlets, diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index af8f74eb42..00688a7325 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -26,20 +26,14 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b'inline; filename="aze%20rty"': u"aze%20rty", b'inline; filename="aze\"rty"': u'aze"rty', b'inline; filename="azer;ty"': u"azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", } def tests(self): for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers( - { - b'Content-Disposition': [hdr], - }, - ) + res = get_filename_from_headers({b'Content-Disposition': [hdr]}) self.assertEqual( - res, expected, - "expected output for %s to be %s but was %s" % ( - hdr, expected, res, - ) + res, + expected, + "expected output for %s to be %s but was %s" % (hdr, expected, res), ) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 650ce95a6f..f696395f3c 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -297,12 +297,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -318,12 +318,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): request.render(self.preview_url) self.pump() - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -339,7 +339,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 403) self.assertEqual( channel.json_body, { @@ -347,6 +346,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): 'error': 'IP address blocked by IP blacklist entry', }, ) + self.assertEqual(channel.code, 403) def test_blacklisted_ip_range_direct(self): """ @@ -414,12 +414,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) request.render(self.preview_url) self.pump() - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -439,12 +439,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): # No requests made. self.assertEqual(len(self.reactor.tcpClients), 0) - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) @@ -460,11 +460,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): request.render(self.preview_url) self.pump() - self.assertEqual(channel.code, 403) + self.assertEqual(channel.code, 502) self.assertEqual( channel.json_body, { 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + 'error': 'DNS resolution failure during URL preview generation', }, ) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 8d8f03e005..b090bb974c 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -31,27 +31,24 @@ class WellKnownTests(unittest.HomeserverTestCase): self.hs.config.default_identity_server = "https://testis" request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) self.assertEqual(request.code, 200) self.assertEqual( - channel.json_body, { + channel.json_body, + { "m.homeserver": {"base_url": "https://tesths"}, "m.identity_server": {"base_url": "https://testis"}, - } + }, ) def test_well_known_no_public_baseurl(self): self.hs.config.public_baseurl = None request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) diff --git a/tests/rulecheck/test_domainrulecheck.py b/tests/rulecheck/test_domainrulecheck.py index 66b9cca4b3..3bb6939b3e 100644 --- a/tests/rulecheck/test_domainrulecheck.py +++ b/tests/rulecheck/test_domainrulecheck.py @@ -16,8 +16,9 @@ import json +import synapse.rest.admin from synapse.config._base import ConfigError -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, room from synapse.rulecheck.domain_rule_checker import DomainRuleChecker from tests import unittest @@ -155,7 +156,7 @@ class DomainRuleCheckerTestCase(unittest.TestCase): class DomainRuleCheckerRoomTestCase(unittest.HomeserverTestCase): servlets = [ - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/server.py b/tests/server.py index 8f89f4a83d..fc41345488 100644 --- a/tests/server.py +++ b/tests/server.py @@ -182,7 +182,8 @@ def make_request( if federation_auth_origin is not None: req.requestHeaders.addRawHeader( - b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) + b"Authorization", + b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,), ) if content: @@ -233,7 +234,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): class FakeResolver(object): def getHostByName(self, name, timeout=None): if name not in lookups: - return fail(DNSLookupError("OH NO: unknown %s" % (name, ))) + return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) @@ -454,6 +455,6 @@ class FakeTransport(object): logger.warning("Exception writing to protocol: %s", e) return - self.buffer = self.buffer[len(to_write):] + self.buffer = self.buffer[len(to_write) :] if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 95badc985e..e0b4e0eb63 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from synapse.rest.client.v1 import admin, login, room +import synapse.rest.admin +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync from tests import unittest @@ -23,7 +23,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): servlets = [ sync.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, room.register_servlets, ] diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index be73e718c2..a490b81ed4 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,7 +27,6 @@ from tests import unittest class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): hs_config = self.default_config("test") hs_config.server_notices_mxid = "@server:test" diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f448b01326..9c5311d916 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -50,6 +50,7 @@ class FakeEvent(object): refer to events. The event_id has node_id as localpart and example.com as domain. """ + def __init__(self, id, sender, type, state_key, content): self.node_id = id self.event_id = EventID(id, "example.com").to_string() @@ -142,24 +143,14 @@ INITIAL_EVENTS = [ content=MEMBERSHIP_CONTENT_JOIN, ), FakeEvent( - id="START", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="START", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), FakeEvent( - id="END", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="END", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), ] -INITIAL_EDGES = [ - "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", -] +INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] class StateTestCase(unittest.TestCase): @@ -170,12 +161,7 @@ class StateTestCase(unittest.TestCase): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="MA", @@ -196,19 +182,11 @@ class StateTestCase(unittest.TestCase): sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), ] - edges = [ - ["END", "MB", "MA", "PA", "START"], - ["END", "PB", "PA"], - ] + edges = [["END", "MB", "MA", "PA", "START"], ["END", "PB", "PA"]] expected_state_ids = ["PA", "MA", "MB"] @@ -232,10 +210,7 @@ class StateTestCase(unittest.TestCase): ), ] - edges = [ - ["END", "JR", "START"], - ["END", "ME", "START"], - ] + edges = [["END", "JR", "START"], ["END", "ME", "START"]] expected_state_ids = ["JR"] @@ -248,45 +223,25 @@ class StateTestCase(unittest.TestCase): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] - edges = [ - ["END", "PC", "PB", "PA", "START"], - ["END", "PA"], - ] + edges = [["END", "PC", "PB", "PA", "START"], ["END", "PA"]] expected_state_ids = ["PC"] @@ -295,68 +250,38 @@ class StateTestCase(unittest.TestCase): def test_topic_basic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), ] - edges = [ - ["END", "PA2", "T2", "PA1", "T1", "START"], - ["END", "T3", "PB", "PA1"], - ] + edges = [["END", "PA2", "T2", "PA1", "T1", "START"], ["END", "T3", "PB", "PA1"]] expected_state_ids = ["PA2", "T2"] @@ -365,30 +290,17 @@ class StateTestCase(unittest.TestCase): def test_topic_reset(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MB", @@ -399,10 +311,7 @@ class StateTestCase(unittest.TestCase): ), ] - edges = [ - ["END", "MB", "T2", "PA", "T1", "START"], - ["END", "T1"], - ] + edges = [["END", "MB", "T2", "PA", "T1", "START"], ["END", "T1"]] expected_state_ids = ["T1", "MB", "PA"] @@ -411,61 +320,34 @@ class StateTestCase(unittest.TestCase): def test_topic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MZ1", @@ -475,11 +357,7 @@ class StateTestCase(unittest.TestCase): content={}, ), FakeEvent( - id="T4", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), ] @@ -587,13 +465,7 @@ class StateTestCase(unittest.TestCase): class LexicographicalTestCase(unittest.TestCase): def test_simple(self): - graph = { - "l": {"o"}, - "m": {"n", "o"}, - "n": {"o"}, - "o": set(), - "p": {"o"}, - } + graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}} res = list(lexicographical_topological_sort(graph, key=lambda x: x)) @@ -680,7 +552,13 @@ class SimpleParamStateTestCase(unittest.TestCase): self.expected_combined_state = { (e.type, e.state_key): e.event_id - for e in [create_event, alice_member, join_rules, bob_member, charlie_member] + for e in [ + create_event, + alice_member, + join_rules, + bob_member, + charlie_member, + ] } def test_event_map_none(self): @@ -720,11 +598,7 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return { - eid: self.event_map[eid] - for eid in event_ids - if eid in self.event_map - } + return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} def get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 5568a607c7..fbb9302694 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,9 +9,7 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup - ) + hs = yield setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index f18db8c384..c778de1f0c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -56,10 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False hs = TestHomeServer( - "test", - db_pool=self.db_pool, - config=config, - database_engine=fake_engine, + "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) self.datastore = SQLBaseStore(None, hs) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 858efe4992..b62eae7abc 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -18,8 +18,9 @@ from mock import Mock from twisted.internet import defer +import synapse.rest.admin from synapse.http.site import XForwardedForRequest -from synapse.rest.client.v1 import admin, login +from synapse.rest.client.v1 import login from tests import unittest @@ -205,7 +206,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase): - servlets = [admin.register_servlets, login.register_servlets] + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver() diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 11fb8c0c19..cd2bcd4ca3 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -20,7 +20,6 @@ import tests.utils class EndToEndKeyStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6569a82bb..f458c03054 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -56,8 +56,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.register(user_id=user1, token="123", password_hash=None) self.store.register(user_id=user2, token="456", password_hash=None) self.store.register( - user_id=user3, token="789", - password_hash=None, user_type=UserTypes.SUPPORT + user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT ) self.pump() @@ -173,9 +172,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) @@ -187,13 +184,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed( - self.hs.get_clock().time_msec() - ) + return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.store.populate_monthly_active_users('user_id') self.pump() @@ -243,7 +236,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): user_id=support_user_id, token="123", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) self.store.upsert_monthly_active_user(support_user_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0fc5019e9f..4823d44dec 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -60,7 +60,7 @@ class RedactionTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,7 +83,7 @@ class RedactionTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"body": body, "msgtype": u"message"}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -105,7 +105,7 @@ class RedactionTestCase(unittest.TestCase): "room_id": room.to_string(), "content": {"reason": reason}, "redacts": event_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index cb3cc4d2e5..c0e0155bb4 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -116,7 +116,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user_id=SUPPORT_USER, token="456", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 063387863e..73ed943f5a 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -58,7 +58,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"membership": membership}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 78e260a7fa..b6169436de 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) @@ -57,7 +56,7 @@ class StateStoreTestCase(tests.unittest.TestCase): "state_key": state_key, "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,15 +82,14 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups_ids( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - { - (EventTypes.Create, ''): e1.event_id, - (EventTypes.Name, ''): e2.event_id, - }, + {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, ) @defer.inlineCallbacks @@ -103,15 +101,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups( - self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] - self.assertEqual( - {ev.event_id for ev in state_list}, - {e1.event_id, e2.event_id}, - ) + self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) @defer.inlineCallbacks def test_get_state_for_event(self): @@ -147,9 +141,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event( - e5.event_id, - ) + state = yield self.store.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -194,7 +186,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, include_others=True, - ) + ), ) self.assertStateMapEqual( @@ -208,9 +200,9 @@ class StateStoreTestCase(tests.unittest.TestCase): # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True ), ) @@ -229,10 +221,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -249,8 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -263,8 +254,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -281,8 +271,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -302,8 +291,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -320,8 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -334,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -384,10 +370,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -399,8 +385,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -413,8 +398,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -425,8 +409,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -445,8 +428,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -457,8 +439,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -471,8 +452,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -483,8 +463,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index fd3361404f..d7d244ce97 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -36,9 +36,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): yield self.store.update_profile_in_user_dir(ALICE, "alice", None) yield self.store.update_profile_in_user_dir(BOB, "bob", None) yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - yield self.store.add_users_in_public_rooms( - "!room:id", (ALICE, BOB) - ) + yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) @defer.inlineCallbacks def test_search_user_dir(self): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 4c8f87e958..8b2741d277 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -37,7 +37,9 @@ class EventAuthTestCase(unittest.TestCase): # creator should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(creator), auth_events, + RoomVersions.V1.identifier, + _random_state_event(creator), + auth_events, do_sig_check=False, ) @@ -82,7 +84,9 @@ class EventAuthTestCase(unittest.TestCase): # king should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(king), auth_events, + RoomVersions.V1.identifier, + _random_state_event(king), + auth_events, do_sig_check=False, ) diff --git a/tests/test_federation.py b/tests/test_federation.py index 1a5dc32c88..6a8339b561 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,4 +1,3 @@ - from mock import Mock from twisted.internet.defer import maybeDeferred, succeed diff --git a/tests/test_mau.py b/tests/test_mau.py index 00be1a8c21..1fbe0d51ff 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -33,9 +33,7 @@ class TestMauLimit(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), + "red", http_client=None, federation_client=Mock() ) self.store = self.hs.get_datastore() @@ -210,9 +208,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return access_token def do_sync_for_user(self, token): - request, channel = self.make_request( - "GET", "/sync", access_token=token - ) + request, channel = self.make_request("GET", "/sync", access_token=token) self.render(request) if channel.code != 200: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 0ff6d0e283..2edbae5c6d 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -44,9 +44,7 @@ def get_sample_labels_value(sample): class TestMauLimit(unittest.TestCase): def test_basic(self): gauge = InFlightGauge( - "test1", "", - labels=["test_label"], - sub_metrics=["foo", "bar"], + "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] ) def handle1(metrics): @@ -59,37 +57,49 @@ class TestMauLimit(unittest.TestCase): gauge.register(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 1}, - "test1_foo": {("key1",): 2}, - "test1_bar": {("key1",): 5}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 0}, - "test1_foo": {("key1",): 0}, - "test1_bar": {("key1",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.register(("key1",), handle1) gauge.register(("key2",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 1, ("key2",): 1}, - "test1_foo": {("key1",): 2, ("key2",): 3}, - "test1_bar": {("key1",): 5, ("key2",): 7}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key2",), handle2) gauge.register(("key1",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 2, ("key2",): 0}, - "test1_foo": {("key1",): 5, ("key2",): 0}, - "test1_bar": {("key1",): 7, ("key2",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) def get_metrics_from_gauge(self, gauge): results = {} diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0968e86a7b..f412985d2c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -69,10 +69,10 @@ class TermsTestCase(unittest.HomeserverTestCase): "name": "My Cool Privacy Policy", "url": "https://example.org/_matrix/consent?v=1.0", }, - "version": "1.0" - }, - }, - }, + "version": "1.0", + } + } + } } self.assertIsInstance(channel.json_body["params"], dict) self.assertDictContainsSubset(channel.json_body["params"], expected_params) diff --git a/tests/test_types.py b/tests/test_types.py index d314a7ff58..d83c36559f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -94,8 +94,7 @@ class MapUsernameTestCase(unittest.TestCase): def testSymbols(self): self.assertEqual( - map_username_to_mxid_localpart("test=$?_1234"), - "test=3d=24=3f_1234", + map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234" ) def testLeadingUnderscore(self): @@ -105,6 +104,5 @@ class MapUsernameTestCase(unittest.TestCase): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") self.assertEqual( - map_username_to_mxid_localpart(u'têst'.encode('utf-8')), - "t=c3=aast", + map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast" ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index d0bc8e2112..fde0baee8e 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -22,6 +22,7 @@ from synapse.util.logcontext import LoggingContextFilter class ToTwistedHandler(logging.Handler): """logging handler which sends the logs to the twisted log""" + tx_log = twisted.logger.Logger() def emit(self, record): @@ -41,7 +42,8 @@ def setup_logging(): root_logger = logging.getLogger() log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s" + "%(asctime)s - %(name)s - %(lineno)d - " + "%(levelname)s - %(request)s - %(message)s" ) handler = ToTwistedHandler() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 3bdb500514..6a180ddc32 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -132,7 +132,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "state_key": "", "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -153,7 +153,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "state_key": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -174,7 +174,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "sender": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/unittest.py b/tests/unittest.py index 8c65736a51..94df8cf47e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -84,9 +84,8 @@ class TestCase(unittest.TestCase): # all future bets are off. if LoggingContext.current_context() is not LoggingContext.sentinel: self.fail( - "Test starting with non-sentinel logging context %s" % ( - LoggingContext.current_context(), - ) + "Test starting with non-sentinel logging context %s" + % (LoggingContext.current_context(),) ) old_level = logging.getLogger().level @@ -181,10 +180,7 @@ class HomeserverTestCase(TestCase): raise Exception("A homeserver wasn't returned, but %r" % (self.hs,)) # Register the resources - self.resource = JsonResource(self.hs) - - for servlet in self.servlets: - servlet(self.hs, self.resource) + self.resource = self.create_test_json_resource() from tests.rest.client.v1.utils import RestHelper @@ -230,6 +226,23 @@ class HomeserverTestCase(TestCase): hs = self.setup_test_homeserver() return hs + def create_test_json_resource(self): + """ + Create a test JsonResource, with the relevant servlets registerd to it + + The default implementation calls each function in `servlets` to do the + registration. + + Returns: + JsonResource: + """ + resource = JsonResource(self.hs) + + for servlet in self.servlets: + servlet(self.hs, resource) + + return resource + def default_config(self, name="test"): """ Get a default HomeServer config object. @@ -286,7 +299,13 @@ class HomeserverTestCase(TestCase): content = json.dumps(content).encode('utf8') return make_request( - self.reactor, method, path, content, access_token, request, shorthand, + self.reactor, + method, + path, + content, + access_token, + request, + shorthand, federation_auth_origin, ) diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index 84dd71e47a..bf85d3b8ec 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -42,10 +42,10 @@ class TimeoutDeferredTest(TestCase): self.assertNoResult(timing_out_d) self.assertFalse(cancelled[0], "deferred was cancelled prematurely") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_times_out_when_canceller_throws(self): """Test that we have successfully worked around @@ -59,9 +59,9 @@ class TimeoutDeferredTest(TestCase): self.assertNoResult(timing_out_d) - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_logcontext_is_preserved_on_cancellation(self): blocking_was_cancelled = [False] @@ -80,10 +80,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), context_one, - "errback %s run in unexpected logcontext %s" % ( - deferred_name, LoggingContext.current_context(), - ) + LoggingContext.current_context(), + context_one, + "errback %s run in unexpected logcontext %s" + % (deferred_name, LoggingContext.current_context()), ) return res @@ -94,11 +94,10 @@ class TimeoutDeferredTest(TestCase): self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) timing_out_d.addErrback(errback, "timingout") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue( - blocking_was_cancelled[0], - "non-completing deferred was not cancelled", + blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(LoggingContext.current_context(), context_one) diff --git a/tests/utils.py b/tests/utils.py index e6e6cb4c75..176093d8c4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -68,7 +68,9 @@ def setupdb(): # connect to postgres to create the base database. db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -94,7 +96,9 @@ def setupdb(): def _cleanup(): db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -114,7 +118,6 @@ def default_config(name): "server_name": name, "media_store_path": "media", "uploads_path": "uploads", - # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", |