summary refs log tree commit diff
path: root/tests/http
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/http/__init__.py2
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py87
-rw-r--r--tests/http/federation/test_srv_resolver.py26
-rw-r--r--tests/http/test_additional_resource.py62
-rw-r--r--tests/http/test_fedclient.py100
-rw-r--r--tests/http/test_servlet.py80
6 files changed, 297 insertions, 60 deletions
diff --git a/tests/http/__init__.py b/tests/http/__init__.py

index 2096ba3c91..5d41443293 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py
@@ -133,7 +133,7 @@ def create_test_cert_file(sanlist): @implementer(IOpenSSLServerConnectionCreator) -class TestServerTLSConnectionFactory(object): +class TestServerTLSConnectionFactory: """An SSL connection creator which returns connections which present a certificate signed by our test CA.""" diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 562397cdda..8b5ad4574f 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory(): return test_server_connection_factory +# Once Async Mocks or lambdas are supported this can go away. +def generate_resolve_service(result): + async def resolve_service(_): + return result + + return resolve_service + + class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -86,6 +94,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.well_known_resolver = WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=self.tls_factory), + b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ) @@ -93,6 +102,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=self.tls_factory, + user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) @@ -186,6 +196,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) request = well_known_server.requests[0] + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) self._send_well_known_response(request, content, headers=response_headers) return well_known_server @@ -231,6 +244,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) content = request.content.read() self.assertEqual(content, b"") @@ -365,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv1/foo/bar") @@ -448,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has no port, no SRV, and no well-known """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -502,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known delegates elsewhere """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -564,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -653,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -709,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -719,10 +735,12 @@ class MatrixFederationAgentTests(unittest.TestCase): agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=tls_factory, + user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( self.reactor, Agent(self.reactor, contextFactory=tls_factory), + b"test-agent", well_known_cache=self.well_known_cache, had_well_known_cache=self.had_well_known_cache, ), @@ -754,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self.reactor.lookups["srvtarget"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -809,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443) - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"srvtarget", port=8443)] + ) self._handle_well_known_connection( client_factory, @@ -851,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" - self.mock_resolver.resolve_service.side_effect = lambda _: [] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" @@ -912,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"xn--trget-3qa.com", port=8443) # târget.com - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com + ) self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") @@ -954,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -977,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase): well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") @@ -985,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1008,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -1034,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # another lookup. self.reactor.pump((900.0,)) - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # The resolver may retry a few times, so fonx all requests that come along attempts = 0 @@ -1064,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((10000.0,)) # Repated the request, this time it should fail if the lookup fails. - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1077,11 +1107,12 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails. """ - - self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"target.com", port=8443), - Server(host=b"target.com", port=8444), - ] + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + ) self.reactor.lookups["target.com"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv/foo/bar") @@ -1233,7 +1264,7 @@ def _log_request(request): @implementer(IPolicyForHTTPS) -class TrustingTLSPolicyForHTTPS(object): +class TrustingTLSPolicyForHTTPS: """An IPolicyForHTTPS which checks that the certificate belongs to the right server, but doesn't check the certificate chain.""" diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context +from synapse.logging.context import LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase): with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) - - self.assertNoResult(resolve_d) - - # should have reset to the sentinel context - self.assertIs(current_context(), SENTINEL_CONTEXT) - - result = yield resolve_d + result = yield defer.ensureDeferred(resolve_d) # should have restored our context self.assertIs(current_context(), ctx) @@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) dns_client_mock.lookupService.assert_called_once_with(service_name) @@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client=dns_client_mock, cache=cache, get_time=clock.time ) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertFalse(dns_client_mock.lookupService.called) @@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase): resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) with self.assertRaises(error.DNSServerError): - yield resolver.resolve_service(service_name) + yield defer.ensureDeferred(resolver.resolve_service(service_name)) @defer.inlineCallbacks def test_name_error(self): @@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - servers = yield resolver.resolve_service(service_name) + servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) self.assertEquals(len(servers), 0) self.assertEquals(len(cache), 0) @@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in failureResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) # returning a single "." should make the lookup fail with a ConenctError lookup_deferred.callback( @@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase): cache = {} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) - resolve_d = resolver.resolve_service(service_name) - self.assertNoResult(resolve_d) + # Old versions of Twisted don't have an ensureDeferred in successResultOf. + resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name)) lookup_deferred.callback( ( diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py new file mode 100644
index 0000000000..62d36c2906 --- /dev/null +++ b/tests/http/test_additional_resource.py
@@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.http.additional_resource import AdditionalResource +from synapse.http.server import respond_with_json + +from tests.unittest import HomeserverTestCase + + +class _AsyncTestCustomEndpoint: + def __init__(self, config, module_api): + pass + + async def handle_request(self, request): + respond_with_json(request, 200, {"some_key": "some_value_async"}) + + +class _SyncTestCustomEndpoint: + def __init__(self, config, module_api): + pass + + async def handle_request(self, request): + respond_with_json(request, 200, {"some_key": "some_value_sync"}) + + +class AdditionalResourceTests(HomeserverTestCase): + """Very basic tests that `AdditionalResource` works correctly with sync + and async handlers. + """ + + def test_async(self): + handler = _AsyncTestCustomEndpoint({}, None).handle_request + self.resource = AdditionalResource(self.hs, handler) + + request, channel = self.make_request("GET", "/") + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) + + def test_sync(self): + handler = _SyncTestCustomEndpoint({}, None).handle_request + self.resource = AdditionalResource(self.hs, handler) + + request, channel = self.make_request("GET", "/") + self.render(request) + + self.assertEqual(request.code, 200) + self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..5604af3795 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -16,6 +16,7 @@ from mock import Mock from netaddr import IPSet +from parameterized import parameterized from twisted.internet import defer from twisted.internet.defer import TimeoutError @@ -58,7 +59,9 @@ class FederationClientTests(HomeserverTestCase): @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: - fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + fetch_d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar") + ) # Nothing happened yet self.assertNoResult(fetch_d) @@ -120,7 +123,9 @@ class FederationClientTests(HomeserverTestCase): """ If the DNS lookup returns an error, it will bubble up. """ - d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + ) self.pump() f = self.failureResultOf(d) @@ -128,7 +133,9 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value.inner_exception, DNSLookupError) def test_client_connection_refused(self): - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -154,7 +161,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -184,7 +193,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -226,7 +237,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a blacklisted IPv4 address # ------------------------------------------------------ # Make the request - d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000)) # Nothing happened yet self.assertNoResult(d) @@ -244,7 +255,9 @@ class FederationClientTests(HomeserverTestCase): # Try making a POST request to a blacklisted IPv6 address # ------------------------------------------------------- # Make the request - d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + ) # Nothing has happened yet self.assertNoResult(d) @@ -263,7 +276,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a non-blacklisted IPv4 address # ---------------------------------------------------------- # Make the request - d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000)) # Nothing has happened yet self.assertNoResult(d) @@ -286,7 +299,7 @@ class FederationClientTests(HomeserverTestCase): request = MatrixFederationRequest( method="GET", destination="testserv:8008", path="foo/bar" ) - d = self.cl._send_request(request, timeout=10000) + d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000)) self.pump() @@ -310,7 +323,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -342,7 +357,9 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -395,7 +412,9 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -432,7 +451,11 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) + defer.ensureDeferred( + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} + ) + ) self.pump() @@ -453,7 +476,7 @@ class FederationClientTests(HomeserverTestCase): def test_closes_connection(self): """Check that the client closes unused HTTP connections""" - d = self.cl.get_json("testserv:8008", "foo/bar") + d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) self.pump() @@ -486,6 +509,53 @@ class FederationClientTests(HomeserverTestCase): self.assertFalse(conn.disconnecting) # wait for a while - self.pump(120) + self.reactor.advance(120) self.assertTrue(conn.disconnecting) + + @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)]) + def test_json_error(self, return_value): + """ + Test what happens if invalid JSON is returned from the remote endpoint. + """ + + test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) + + self.pump() + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8008) + + # complete the connection and wire it up to a fake transport + protocol = factory.buildProtocol(None) + transport = StringTransport() + protocol.makeConnection(transport) + + # that should have made it send the request to the transport + self.assertRegex(transport.value(), b"^GET /foo/bar") + self.assertRegex(transport.value(), b"Host: testserv:8008") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # Send it the HTTP response + protocol.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Server: Fake\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: %i\r\n" + b"\r\n" + b"%s" % (len(return_value), return_value) + ) + + self.pump() + + f = self.failureResultOf(test_d) + self.assertIsInstance(f.value, ValueError) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py new file mode 100644
index 0000000000..45089158ce --- /dev/null +++ b/tests/http/test_servlet.py
@@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from io import BytesIO + +from mock import Mock + +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + parse_json_object_from_request, + parse_json_value_from_request, +) + +from tests import unittest + + +def make_request(content): + """Make an object that acts enough like a request.""" + request = Mock(spec=["content"]) + + if isinstance(content, dict): + content = json.dumps(content).encode("utf8") + + request.content = BytesIO(content) + return request + + +class TestServletUtils(unittest.TestCase): + def test_parse_json_value(self): + """Basic tests for parse_json_value_from_request.""" + # Test round-tripping. + obj = {"foo": 1} + result = parse_json_value_from_request(make_request(obj)) + self.assertEqual(result, obj) + + # Results don't have to be objects. + result = parse_json_value_from_request(make_request(b'["foo"]')) + self.assertEqual(result, ["foo"]) + + # Test empty. + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b"")) + + result = parse_json_value_from_request(make_request(b""), allow_empty_body=True) + self.assertIsNone(result) + + # Invalid UTF-8. + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b"\xFF\x00")) + + # Invalid JSON. + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b"foo")) + + with self.assertRaises(SynapseError): + parse_json_value_from_request(make_request(b'{"foo": Infinity}')) + + def test_parse_json_object(self): + """Basic tests for parse_json_object_from_request.""" + # Test empty. + result = parse_json_object_from_request( + make_request(b""), allow_empty_body=True + ) + self.assertEqual(result, {}) + + # Test not an object + with self.assertRaises(SynapseError): + parse_json_object_from_request(make_request(b'["foo"]'))