diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2096ba3c91..5d41443293 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -133,7 +133,7 @@ def create_test_cert_file(sanlist):
@implementer(IOpenSSLServerConnectionCreator)
-class TestServerTLSConnectionFactory(object):
+class TestServerTLSConnectionFactory:
"""An SSL connection creator which returns connections which present a certificate
signed by our test CA."""
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 562397cdda..8b5ad4574f 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory():
return test_server_connection_factory
+# Once Async Mocks or lambdas are supported this can go away.
+def generate_resolve_service(result):
+ async def resolve_service(_):
+ return result
+
+ return resolve_service
+
+
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
@@ -86,6 +94,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_resolver = WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=self.tls_factory),
+ b"test-agent",
well_known_cache=self.well_known_cache,
had_well_known_cache=self.had_well_known_cache,
)
@@ -93,6 +102,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
+ user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided.
_srv_resolver=self.mock_resolver,
_well_known_resolver=self.well_known_resolver,
)
@@ -186,6 +196,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
request = well_known_server.requests[0]
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+ )
self._send_well_known_response(request, content, headers=response_headers)
return well_known_server
@@ -231,6 +244,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]
)
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+ )
content = request.content.read()
self.assertEqual(content, b"")
@@ -365,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when the certificate on the server doesn't match the hostname
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@@ -448,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has no port, no SRV, and no well-known
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -502,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -564,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the server name has no port and no SRV record, but
the .well-known has a 300 redirect
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@@ -653,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
Test the behaviour when the server name has an *invalid* well-known (and no SRV)
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -709,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
@@ -719,10 +735,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=tls_factory,
+ user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
_srv_resolver=self.mock_resolver,
_well_known_resolver=WellKnownResolver(
self.reactor,
Agent(self.reactor, contextFactory=tls_factory),
+ b"test-agent",
well_known_cache=self.well_known_cache,
had_well_known_cache=self.had_well_known_cache,
),
@@ -754,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
Test the behaviour when there is a single SRV record
"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -809,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"srvtarget", port=8443)
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"srvtarget", port=8443)]
+ )
self._handle_well_known_connection(
client_factory,
@@ -851,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
- self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -912,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
+ )
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@@ -954,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -977,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -985,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1008,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -1034,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# another lookup.
self.reactor.pump((900.0,))
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# The resolver may retry a few times, so fonx all requests that come along
attempts = 0
@@ -1064,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((10000.0,))
# Repated the request, this time it should fail if the lookup fails.
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1077,11 +1107,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_srv_fallbacks(self):
"""Test that other SRV results are tried if the first one fails.
"""
-
- self.mock_resolver.resolve_service.side_effect = lambda _: [
- Server(host=b"target.com", port=8443),
- Server(host=b"target.com", port=8444),
- ]
+ self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+ [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ )
self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -1233,7 +1264,7 @@ def _log_request(request):
@implementer(IPolicyForHTTPS)
-class TrustingTLSPolicyForHTTPS(object):
+class TrustingTLSPolicyForHTTPS:
"""An IPolicyForHTTPS which checks that the certificate belongs to the
right server, but doesn't check the certificate chain."""
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
from twisted.names import dns, error
from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
@@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
with LoggingContext("one") as ctx:
resolve_d = resolver.resolve_service(service_name)
-
- self.assertNoResult(resolve_d)
-
- # should have reset to the sentinel context
- self.assertIs(current_context(), SENTINEL_CONTEXT)
-
- result = yield resolve_d
+ result = yield defer.ensureDeferred(resolve_d)
# should have restored our context
self.assertIs(current_context(), ctx)
@@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client=dns_client_mock, cache=cache, get_time=clock.time
)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertFalse(dns_client_mock.lookupService.called)
@@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
with self.assertRaises(error.DNSServerError):
- yield resolver.resolve_service(service_name)
+ yield defer.ensureDeferred(resolver.resolve_service(service_name))
@defer.inlineCallbacks
def test_name_error(self):
@@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- servers = yield resolver.resolve_service(service_name)
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
@@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
# returning a single "." should make the lookup fail with a ConenctError
lookup_deferred.callback(
@@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
cache = {}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
- resolve_d = resolver.resolve_service(service_name)
- self.assertNoResult(resolve_d)
+ # Old versions of Twisted don't have an ensureDeferred in successResultOf.
+ resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
lookup_deferred.callback(
(
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
new file mode 100644
index 0000000000..62d36c2906
--- /dev/null
+++ b/tests/http/test_additional_resource.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.http.additional_resource import AdditionalResource
+from synapse.http.server import respond_with_json
+
+from tests.unittest import HomeserverTestCase
+
+
+class _AsyncTestCustomEndpoint:
+ def __init__(self, config, module_api):
+ pass
+
+ async def handle_request(self, request):
+ respond_with_json(request, 200, {"some_key": "some_value_async"})
+
+
+class _SyncTestCustomEndpoint:
+ def __init__(self, config, module_api):
+ pass
+
+ async def handle_request(self, request):
+ respond_with_json(request, 200, {"some_key": "some_value_sync"})
+
+
+class AdditionalResourceTests(HomeserverTestCase):
+ """Very basic tests that `AdditionalResource` works correctly with sync
+ and async handlers.
+ """
+
+ def test_async(self):
+ handler = _AsyncTestCustomEndpoint({}, None).handle_request
+ self.resource = AdditionalResource(self.hs, handler)
+
+ request, channel = self.make_request("GET", "/")
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
+
+ def test_sync(self):
+ handler = _SyncTestCustomEndpoint({}, None).handle_request
+ self.resource = AdditionalResource(self.hs, handler)
+
+ request, channel = self.make_request("GET", "/")
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..5604af3795 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -16,6 +16,7 @@
from mock import Mock
from netaddr import IPSet
+from parameterized import parameterized
from twisted.internet import defer
from twisted.internet.defer import TimeoutError
@@ -58,7 +59,9 @@ class FederationClientTests(HomeserverTestCase):
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
- fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+ fetch_d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar")
+ )
# Nothing happened yet
self.assertNoResult(fetch_d)
@@ -120,7 +123,9 @@ class FederationClientTests(HomeserverTestCase):
"""
If the DNS lookup returns an error, it will bubble up.
"""
- d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+ )
self.pump()
f = self.failureResultOf(d)
@@ -128,7 +133,9 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -154,7 +161,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -184,7 +193,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -226,7 +237,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a blacklisted IPv4 address
# ------------------------------------------------------
# Make the request
- d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
# Nothing happened yet
self.assertNoResult(d)
@@ -244,7 +255,9 @@ class FederationClientTests(HomeserverTestCase):
# Try making a POST request to a blacklisted IPv6 address
# -------------------------------------------------------
# Make the request
- d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+ )
# Nothing has happened yet
self.assertNoResult(d)
@@ -263,7 +276,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a non-blacklisted IPv4 address
# ----------------------------------------------------------
# Make the request
- d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
# Nothing has happened yet
self.assertNoResult(d)
@@ -286,7 +299,7 @@ class FederationClientTests(HomeserverTestCase):
request = MatrixFederationRequest(
method="GET", destination="testserv:8008", path="foo/bar"
)
- d = self.cl._send_request(request, timeout=10000)
+ d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
self.pump()
@@ -310,7 +323,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
- d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+ d = defer.ensureDeferred(
+ self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+ )
self.pump()
@@ -342,7 +357,9 @@ class FederationClientTests(HomeserverTestCase):
requiring a trailing slash. We need to retry the request with a
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -395,7 +412,9 @@ class FederationClientTests(HomeserverTestCase):
See test_client_requires_trailing_slashes() for context.
"""
- d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ d = defer.ensureDeferred(
+ self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+ )
# Send the request
self.pump()
@@ -432,7 +451,11 @@ class FederationClientTests(HomeserverTestCase):
self.failureResultOf(d)
def test_client_sends_body(self):
- self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+ defer.ensureDeferred(
+ self.cl.post_json(
+ "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+ )
+ )
self.pump()
@@ -453,7 +476,7 @@ class FederationClientTests(HomeserverTestCase):
def test_closes_connection(self):
"""Check that the client closes unused HTTP connections"""
- d = self.cl.get_json("testserv:8008", "foo/bar")
+ d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
self.pump()
@@ -486,6 +509,53 @@ class FederationClientTests(HomeserverTestCase):
self.assertFalse(conn.disconnecting)
# wait for a while
- self.pump(120)
+ self.reactor.advance(120)
self.assertTrue(conn.disconnecting)
+
+ @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
+ def test_json_error(self, return_value):
+ """
+ Test what happens if invalid JSON is returned from the remote endpoint.
+ """
+
+ test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # that should have made it send the request to the transport
+ self.assertRegex(transport.value(), b"^GET /foo/bar")
+ self.assertRegex(transport.value(), b"Host: testserv:8008")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # Send it the HTTP response
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: %i\r\n"
+ b"\r\n"
+ b"%s" % (len(return_value), return_value)
+ )
+
+ self.pump()
+
+ f = self.failureResultOf(test_d)
+ self.assertIsInstance(f.value, ValueError)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
new file mode 100644
index 0000000000..45089158ce
--- /dev/null
+++ b/tests/http/test_servlet.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from io import BytesIO
+
+from mock import Mock
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ parse_json_object_from_request,
+ parse_json_value_from_request,
+)
+
+from tests import unittest
+
+
+def make_request(content):
+ """Make an object that acts enough like a request."""
+ request = Mock(spec=["content"])
+
+ if isinstance(content, dict):
+ content = json.dumps(content).encode("utf8")
+
+ request.content = BytesIO(content)
+ return request
+
+
+class TestServletUtils(unittest.TestCase):
+ def test_parse_json_value(self):
+ """Basic tests for parse_json_value_from_request."""
+ # Test round-tripping.
+ obj = {"foo": 1}
+ result = parse_json_value_from_request(make_request(obj))
+ self.assertEqual(result, obj)
+
+ # Results don't have to be objects.
+ result = parse_json_value_from_request(make_request(b'["foo"]'))
+ self.assertEqual(result, ["foo"])
+
+ # Test empty.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b""))
+
+ result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
+ self.assertIsNone(result)
+
+ # Invalid UTF-8.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b"\xFF\x00"))
+
+ # Invalid JSON.
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b"foo"))
+
+ with self.assertRaises(SynapseError):
+ parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
+
+ def test_parse_json_object(self):
+ """Basic tests for parse_json_object_from_request."""
+ # Test empty.
+ result = parse_json_object_from_request(
+ make_request(b""), allow_empty_body=True
+ )
+ self.assertEqual(result, {})
+
+ # Test not an object
+ with self.assertRaises(SynapseError):
+ parse_json_object_from_request(make_request(b'["foo"]'))
|