summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-07-30 08:01:33 -0400
committerGitHub <noreply@github.com>2020-07-30 08:01:33 -0400
commitc978f6c4515a631f289aedb1844d8579b9334aaa (patch)
tree105d4069557d4b78c9b983ebfd8581ffad69165d /tests
parentConvert appservice to async. (#7973) (diff)
downloadsynapse-c978f6c4515a631f289aedb1844d8579b9334aaa.tar.xz
Convert federation client to async/await. (#7975)
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py11
-rw-r--r--tests/federation/test_complexity.py21
-rw-r--r--tests/federation/test_federation_sender.py10
-rw-r--r--tests/handlers/test_directory.py5
-rw-r--r--tests/handlers/test_profile.py3
-rw-r--r--tests/http/test_fedclient.py50
-rw-r--r--tests/replication/test_federation_sender_shard.py13
-rw-r--r--tests/rest/admin/test_admin.py4
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/test_federation.py2
10 files changed, 71 insertions, 52 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f9ce609923..e0ad8e8a77 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         }
         persp_deferred = defer.Deferred()
 
-        @defer.inlineCallbacks
-        def get_perspectives(**kwargs):
+        async def get_perspectives(**kwargs):
             self.assertEquals(current_context().request, "11")
             with PreserveLoggingContext():
-                yield persp_deferred
+                await persp_deferred
             return persp_resp
 
         self.http_client.post_json.side_effect = get_perspectives
@@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         }
         signedjson.sign.sign_json(response, SERVER_NAME, testkey)
 
-        def get_json(destination, path, **kwargs):
+        async def get_json(destination, path, **kwargs):
             self.assertEqual(destination, SERVER_NAME)
             self.assertEqual(path, "/_matrix/key/v2/server/key1")
             return response
@@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect a perspectives-server key query
         """
 
-        def post_json(destination, path, data, **kwargs):
+        async def post_json(destination, path, data, **kwargs):
             self.assertEqual(destination, self.mock_perspective_server.server_name)
             self.assertEqual(path, "/_matrix/key/v2/query")
 
@@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # remove the perspectives server's signature
         response = build_response()
         del response["signatures"][self.mock_perspective_server.server_name]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
 
         # remove the origin server's signature
         response = build_response()
         del response["signatures"][SERVER_NAME]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
 
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 5cd0510f0d..b8ca118716 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -78,9 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -109,9 +110,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -147,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         # Artificially raise the complexity
@@ -203,9 +204,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -233,9 +234,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index d1bd18da39..5f512ff8bf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
         self.pump()
         mock_send_transaction.assert_not_called()
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse
 import synapse.api.errors
 from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
 from synapse.types import RoomAlias, create_requester
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
         )
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 4f1347cd25..d70e1fc608 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
@@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_other_name(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
 
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..ac598249e4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase):
         @defer.inlineCallbacks
         def do_request():
             with LoggingContext("one") as context:
-                fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+                fetch_d = defer.ensureDeferred(
+                    self.cl.get_json("testserv:8008", "foo/bar")
+                )
 
                 # Nothing happened yet
                 self.assertNoResult(fetch_d)
@@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase):
         """
         If the DNS lookup returns an error, it will bubble up.
         """
-        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        )
         self.pump()
 
         f = self.failureResultOf(d)
@@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value.inner_exception, DNSLookupError)
 
     def test_client_connection_refused(self):
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError or TimeoutError.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a blacklisted IPv4 address
         # ------------------------------------------------------
         # Make the request
-        d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
 
         # Nothing happened yet
         self.assertNoResult(d)
@@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a POST request to a blacklisted IPv6 address
         # -------------------------------------------------------
         # Make the request
-        d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        )
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a non-blacklisted IPv4 address
         # ----------------------------------------------------------
         # Make the request
-        d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase):
         request = MatrixFederationRequest(
             method="GET", destination="testserv:8008", path="foo/bar"
         )
-        d = self.cl._send_request(request, timeout=10000)
+        d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
 
         self.pump()
 
@@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase):
         requiring a trailing slash. We need to retry the request with a
         trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase):
 
         See test_client_requires_trailing_slashes() for context.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase):
         self.failureResultOf(d)
 
     def test_client_sends_body(self):
-        self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+        defer.ensureDeferred(
+            self.cl.post_json(
+                "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+            )
+        )
 
         self.pump()
 
@@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase):
 
     def test_closes_connection(self):
         """Check that the client closes unused HTTP connections"""
-        d = self.cl.get_json("testserv:8008", "foo/bar")
+        d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
 
         self.pump()
 
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8d4dbf232e..83f9aa291c 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,8 +16,6 @@ import logging
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.builder import EventBuilderFactory
 from synapse.rest.admin import register_servlets_for_client_rest_resource
@@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new event.
         """
         mock_client = Mock(spec=["put_json"])
-        mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
 
         self.make_worker_hs(
             "synapse.app.federation_sender",
@@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new events.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new typing EDUs.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index b1a4decced..0f1144fe1e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         self.fetches = []
 
-        def get_file(destination, path, output_stream, args=None, max_size=None):
+        async def get_file(destination, path, output_stream, args=None, max_size=None):
             """
             Returns tuple[int,dict,str,int] of file length, response headers,
             absolute URI, and response code.
@@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             d = Deferred()
             d.addCallback(write_to)
             self.fetches.append((d, destination, path, args))
-            return make_deferred_yieldable(d)
+            return await make_deferred_yieldable(d)
 
         client = Mock()
         client.get_file = get_file
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect an outgoing GET request for the given key
         """
 
-        def get_json(destination, path, ignore_backoff=False, **kwargs):
+        async def get_json(destination, path, ignore_backoff=False, **kwargs):
             self.assertTrue(ignore_backoff)
             self.assertEqual(destination, server_name)
             key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
         # will be forwarded to hs1
-        def post_json(destination, path, data):
+        async def post_json(destination, path, data):
             self.assertEqual(destination, self.hs.hostname)
             self.assertEqual(
                 path, "/_matrix/key/v2/query",
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 87a16d7d7a..c2f12c2741 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         prev_events that said event references.
         """
 
-        def post_json(destination, path, data, headers=None, timeout=0):
+        async def post_json(destination, path, data, headers=None, timeout=0):
             # If it asks us for new missing events, give them NOTHING
             if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                 return {"events": []}