diff --git a/tests/test_server.py b/tests/test_server.py
index e9a43b1e45..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -12,31 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import re
-from six import StringIO
-
from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError
-from synapse.http.server import (
- DirectServeResource,
- JsonResource,
- OptionsResource,
- wrap_html_request_handler,
-)
-from synapse.http.site import SynapseSite, logger
+from synapse.config.server import parse_listener_def
+from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
+from synapse.http.site import SynapseSite
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
from tests import unittest
from tests.server import (
- FakeTransport,
ThreadedMemoryReactorClock,
make_request,
render,
@@ -168,6 +157,28 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+ def test_head_request(self):
+ """
+ JsonResource.handler_for_request gives correctly decoded URL args to
+ the callback, while Twisted will give the raw bytes of URL query
+ arguments.
+ """
+
+ def _callback(request, **kwargs):
+ return 200, {"result": True}
+
+ res = JsonResource(self.homeserver)
+ res.register_paths(
+ "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+ )
+
+ # The path was registered as GET, but this is a HEAD request.
+ request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
+
class OptionsResourceTests(unittest.TestCase):
def setUp(self):
@@ -189,7 +200,13 @@ class OptionsResourceTests(unittest.TestCase):
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource.
- site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ site = SynapseSite(
+ "test",
+ "site_tag",
+ parse_listener_def({"type": "http", "port": 0}),
+ self.resource,
+ "1.0",
+ )
request.site = site
resource = site.getResourceFor(request)
@@ -198,10 +215,10 @@ class OptionsResourceTests(unittest.TestCase):
return channel
def test_unknown_options_request(self):
- """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ """An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -218,10 +235,10 @@ class OptionsResourceTests(unittest.TestCase):
)
def test_known_options_request(self):
- """An OPTIONS requests to an known URL still returns 200 OK."""
+ """An OPTIONS requests to an known URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/res/")
- self.assertEqual(channel.result["code"], b"200")
- self.assertEqual(channel.result["body"], b"{}")
+ self.assertEqual(channel.result["code"], b"204")
+ self.assertNotIn("body", channel.result)
# Ensure the correct CORS headers have been added
self.assertTrue(
@@ -250,18 +267,17 @@ class OptionsResourceTests(unittest.TestCase):
class WrapHtmlRequestHandlerTests(unittest.TestCase):
- class TestResource(DirectServeResource):
+ class TestResource(DirectServeHtmlResource):
callback = None
- @wrap_html_request_handler
async def _async_render_GET(self, request):
- return await self.callback(request)
+ await self.callback(request)
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
def test_good_response(self):
- def callback(request):
+ async def callback(request):
request.write(b"response")
request.finish()
@@ -281,7 +297,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
with the right location.
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource()
@@ -301,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
returned too
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls")
raise e
@@ -319,51 +335,18 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
+ def test_head_request(self):
+ """A head request should work by being turned into a GET request."""
-class SiteTestCase(unittest.HomeserverTestCase):
- def test_lose_connection(self):
- """
- We log the URI correctly redacted when we lose the connection.
- """
+ async def callback(request):
+ request.write(b"response")
+ request.finish()
- class HangingResource(Resource):
- """
- A Resource that strategically hangs, as if it were processing an
- answer.
- """
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
- def render(self, request):
- return NOT_DONE_YET
-
- # Set up a logging handler that we can inspect afterwards
- output = StringIO()
- handler = logging.StreamHandler(output)
- logger.addHandler(handler)
- old_level = logger.level
- logger.setLevel(10)
- self.addCleanup(logger.setLevel, old_level)
- self.addCleanup(logger.removeHandler, handler)
-
- # Make a resource and a Site, the resource will hang and allow us to
- # time out the request while it's 'processing'
- base_resource = Resource()
- base_resource.putChild(b"", HangingResource())
- site = SynapseSite("test", "site_tag", {}, base_resource, "1.0")
-
- server = site.buildProtocol(None)
- client = AccumulatingProtocol()
- client.makeConnection(FakeTransport(server, self.reactor))
- server.makeConnection(FakeTransport(client, self.reactor))
-
- # Send a request with an access token that will get redacted
- server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n")
- self.pump()
-
- # Lose the connection
- e = Failure(Exception("Failed123"))
- server.connectionLost(e)
- handler.flush()
-
- # Our access token is redacted and the failure reason is logged.
- self.assertIn("/?access_token=<redacted>", output.getvalue())
- self.assertIn("Failed123", output.getvalue())
+ request, channel = make_request(self.reactor, b"HEAD", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
|