diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 578fc48ef4..efecb089c1 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
- def __init__(self, msg):
+ def __init__(self, msg: str):
super().__init__(504, msg)
@@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
-def redact_uri(uri):
+def redact_uri(uri: str) -> str:
"""Strips sensitive information from the uri replaces with <redacted>"""
uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
@@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
https://twistedmatrix.com/trac/ticket/6528
"""
- def stopProducing(self):
+ def stopProducing(self) -> None:
try:
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 9a2684aca4..6a9f6635d2 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from twisted.web.server import Request
@@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""
- def __init__(self, hs: "HomeServer", handler):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
+ ):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
@@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource):
super().__init__()
self._handler = handler
- def _async_render(self, request: Request):
+ async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
- return self._handler(request)
+ return await self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index b5a2d333a6..ca33b45cb2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
+from http import HTTPStatus
from io import BytesIO
from typing import (
TYPE_CHECKING,
@@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
- e = SynapseError(403, "IP address blocked by IP blacklist entry")
+ e = SynapseError(
+ HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
+ )
return defer.fail(Failure(e))
return self._agent.request(
@@ -585,7 +588,7 @@ class SimpleHttpClient:
if headers:
actual_headers.update(headers) # type: ignore
- body = await self.get_raw(uri, args, headers=headers)
+ body = await self.get_raw(uri, args, headers=actual_headers)
return json_decoder.decode(body.decode("utf-8"))
async def put_json(
@@ -719,7 +722,9 @@ class SimpleHttpClient:
if response.code > 299:
logger.warning("Got %d when downloading %s" % (response.code, url))
- raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
+ )
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
@@ -731,12 +736,14 @@ class SimpleHttpClient:
)
except BodyExceededMaxSize:
raise SynapseError(
- 502,
+ HTTPStatus.BAD_GATEWAY,
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
except Exception as e:
- raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
+ ) from e
return (
length,
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 1238bfd287..a8a520f809 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -25,6 +25,7 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import (
+ IProtocol,
IProtocolFactory,
IReactorCore,
IStreamClientEndpoint,
@@ -309,12 +310,14 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
- def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
+ def connect(
+ self, protocol_factory: IProtocolFactory
+ ) -> "defer.Deferred[IProtocol]":
"""Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory)
- async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
first_exception = None
server_list = await self._resolve_server()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 203d723d41..deedde0b5b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,6 +19,7 @@ import random
import sys
import typing
import urllib.parse
+from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
@@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient:
request.destination,
msg,
)
- raise SynapseError(502, msg, Codes.TOO_LARGE)
+ raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 91badb0b0a..09b4125489 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,7 +14,6 @@
# limitations under the License.
import abc
-import collections
import html
import logging
import types
@@ -30,12 +29,14 @@ from typing import (
Iterable,
Iterator,
List,
+ NoReturn,
Optional,
Pattern,
Tuple,
Union,
)
+import attr
import jinja2
from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
@@ -57,12 +58,14 @@ from synapse.api.errors import (
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
-from synapse.logging.opentracing import trace_servlet
+from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
+ import opentracing
+
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -170,7 +173,9 @@ def return_html_error(
respond_with_html(request, code, body)
-def wrap_async_request_handler(h):
+def wrap_async_request_handler(
+ h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
+) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
@@ -183,7 +188,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""
- async def wrapped_async_request_handler(self, request):
+ async def wrapped_async_request_handler(
+ self: "_AsyncResource", request: SynapseRequest
+ ) -> None:
with request.processing():
await h(self, request)
@@ -240,18 +247,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
context from the request the servlet is handling.
"""
- def __init__(self, extract_context=False):
+ def __init__(self, extract_context: bool = False):
super().__init__()
self._extract_context = extract_context
- def render(self, request):
+ def render(self, request: SynapseRequest) -> int:
"""This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@wrap_async_request_handler
- async def _async_render_wrapper(self, request: SynapseRequest):
+ async def _async_render_wrapper(self, request: SynapseRequest) -> None:
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
@@ -271,7 +278,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
f = failure.Failure()
self._send_error_response(f, request)
- async def _async_render(self, request: Request):
+ async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
@@ -318,7 +325,7 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""
- def __init__(self, canonical_json=False, extract_context=False):
+ def __init__(self, canonical_json: bool = False, extract_context: bool = False):
super().__init__(extract_context)
self.canonical_json = canonical_json
@@ -327,7 +334,7 @@ class DirectServeJsonResource(_AsyncResource):
request: SynapseRequest,
code: int,
response_object: Any,
- ):
+ ) -> None:
"""Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
@@ -347,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource):
return_json_error(f, request)
-_PathEntry = collections.namedtuple(
- "_PathEntry", ["pattern", "callback", "servlet_classname"]
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PathEntry:
+ pattern: Pattern
+ callback: ServletCallback
+ servlet_classname: str
class JsonResource(DirectServeJsonResource):
@@ -368,34 +377,45 @@ class JsonResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ canonical_json: bool = True,
+ extract_context: bool = False,
+ ):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs
- def register_paths(self, method, path_patterns, callback, servlet_classname):
+ def register_paths(
+ self,
+ method: str,
+ path_patterns: Iterable[Pattern],
+ callback: ServletCallback,
+ servlet_classname: str,
+ ) -> None:
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
handler for that request.
Args:
- method (str): GET, POST etc
+ method: GET, POST etc
- path_patterns (Iterable[str]): A list of regular expressions to which
- the request URLs are compared.
+ path_patterns: A list of regular expressions to which the request
+ URLs are compared.
- callback (function): The handler for the request. Usually a Servlet
+ callback: The handler for the request. Usually a Servlet
- servlet_classname (str): The name of the handler to be used in prometheus
+ servlet_classname: The name of the handler to be used in prometheus
and opentracing logs.
"""
- method = method.encode("utf-8") # method is bytes on py3
+ method_bytes = method.encode("utf-8")
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
- self.path_regexs.setdefault(method, []).append(
+ self.path_regexs.setdefault(method_bytes, []).append(
_PathEntry(path_pattern, callback, servlet_classname)
)
@@ -427,7 +447,7 @@ class JsonResource(DirectServeJsonResource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}
- async def _async_render(self, request):
+ async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have an appropriate name for this handler in prometheus
@@ -468,7 +488,7 @@ class DirectServeHtmlResource(_AsyncResource):
request: SynapseRequest,
code: int,
response_object: Any,
- ):
+ ) -> None:
"""Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
@@ -492,12 +512,12 @@ class StaticResource(File):
Differs from the File resource by adding clickjacking protection.
"""
- def render_GET(self, request: Request):
+ def render_GET(self, request: Request) -> bytes:
set_clickjacking_protection_headers(request)
return super().render_GET(request)
-def _unrecognised_request_handler(request):
+def _unrecognised_request_handler(request: Request) -> NoReturn:
"""Request handler for unrecognised requests
This is a request handler suitable for return from
@@ -505,7 +525,7 @@ def _unrecognised_request_handler(request):
UnrecognizedRequestError.
Args:
- request (twisted.web.http.Request):
+ request: Unused, but passed in to match the signature of ServletCallback.
"""
raise UnrecognizedRequestError()
@@ -513,23 +533,23 @@ def _unrecognised_request_handler(request):
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
- def __init__(self, path):
- resource.Resource.__init__(self)
+ def __init__(self, path: str):
+ super().__init__()
self.url = path
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
return redirectTo(self.url.encode("ascii"), request)
- def getChild(self, name, request):
+ def getChild(self, name: str, request: Request) -> resource.Resource:
if len(name) == 0:
return self # select ourselves as the child to render
- return resource.Resource.getChild(self, name, request)
+ return super().getChild(name, request)
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
- def render_OPTIONS(self, request):
+ def render_OPTIONS(self, request: Request) -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
@@ -537,10 +557,10 @@ class OptionsResource(resource.Resource):
return b""
- def getChildWithDefault(self, path, request):
+ def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
if request.method == b"OPTIONS":
return self # select ourselves as the child to render
- return resource.Resource.getChildWithDefault(self, path, request)
+ return super().getChildWithDefault(path, request)
class RootOptionsRedirectResource(OptionsResource, RootRedirect):
@@ -649,7 +669,7 @@ def respond_with_json(
json_object: Any,
send_cors: bool = False,
canonical_json: bool = True,
-):
+) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
@@ -696,7 +716,7 @@ def respond_with_json_bytes(
code: int,
json_bytes: bytes,
send_cors: bool = False,
-):
+) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
@@ -713,7 +733,7 @@ def respond_with_json_bytes(
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
- return
+ return None
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@@ -731,7 +751,7 @@ async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
-):
+) -> None:
"""Encodes the given JSON object on a thread and then writes it to the
request.
@@ -743,7 +763,20 @@ async def _async_write_json_to_request_in_thread(
expensive.
"""
- json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
+ def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes:
+ # it might take a while for the threadpool to schedule us, so we write
+ # opentracing logs once we actually get scheduled, so that we can see how
+ # much that contributed.
+ if opentracing_span:
+ opentracing_span.log_kv({"event": "scheduled"})
+ res = json_encoder(json_object)
+ if opentracing_span:
+ opentracing_span.log_kv({"event": "encoded"})
+ return res
+
+ with start_active_span("encode_json_response"):
+ span = active_span()
+ json_str = await defer_to_thread(request.reactor, encode, span)
_write_bytes_to_request(request, json_str)
@@ -773,7 +806,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)
-def set_cors_headers(request: Request):
+def set_cors_headers(request: Request) -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@@ -790,14 +823,14 @@ def set_cors_headers(request: Request):
)
-def respond_with_html(request: Request, code: int, html: str):
+def respond_with_html(request: Request, code: int, html: str) -> None:
"""
Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
"""
respond_with_html_bytes(request, code, html.encode("utf-8"))
-def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
+def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
"""
Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
@@ -815,7 +848,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
- return
+ return None
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
@@ -828,7 +861,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
finish_request(request)
-def set_clickjacking_protection_headers(request: Request):
+def set_clickjacking_protection_headers(request: Request) -> None:
"""
Set headers to guard against clickjacking of embedded content.
@@ -850,7 +883,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
finish_request(request)
-def finish_request(request: Request):
+def finish_request(request: Request) -> None:
"""Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6dd9b9ad03..4ff840ca0e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,6 +14,7 @@
""" This module contains base REST classes for constructing REST servlets. """
import logging
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Iterable,
@@ -30,6 +31,7 @@ from typing_extensions import Literal
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder
@@ -137,11 +139,15 @@ def parse_integer_from_args(
return int(args[name_bytes][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
- raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
else:
if required:
message = "Missing integer query parameter %r" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+ )
else:
return default
@@ -246,11 +252,15 @@ def parse_boolean_from_args(
message = (
"Boolean query parameter %r must be one of ['true', 'false']"
) % (name,)
- raise SynapseError(400, message)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+ )
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+ )
else:
return default
@@ -313,7 +323,7 @@ def parse_bytes_from_args(
return args[name_bytes][0]
elif required:
message = "Missing string query parameter %s" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
return default
@@ -407,14 +417,16 @@ def _parse_string_value(
try:
value_str = value.decode(encoding)
except ValueError:
- raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding)
+ )
if allowed_values is not None and value_str not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name,
", ".join(repr(v) for v in allowed_values),
)
- raise SynapseError(400, message)
+ raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
else:
return value_str
@@ -510,7 +522,9 @@ def parse_strings_from_args(
else:
if required:
message = "Missing string query parameter %r" % (name,)
- raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+ )
return default
@@ -638,7 +652,7 @@ def parse_json_value_from_request(
try:
content_bytes = request.content.read() # type: ignore
except Exception:
- raise SynapseError(400, "Error reading JSON content.")
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.")
if not content_bytes and allow_empty_body:
return None
@@ -647,7 +661,9 @@ def parse_json_value_from_request(
content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
- raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
+ )
return content
@@ -673,7 +689,7 @@ def parse_json_object_from_request(
if not isinstance(content, dict):
message = "Content must be a JSON object."
- raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+ raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
return content
@@ -685,7 +701,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
absent.append(k)
if len(absent) > 0:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM
+ )
class RestServlet:
@@ -709,7 +727,7 @@ class RestServlet:
into the appropriate HTTP response.
"""
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
"""Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None)
if patterns:
@@ -758,10 +776,12 @@ class ResolveRoomIdMixin:
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
- 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ HTTPStatus.BAD_REQUEST,
+ "%s was not legal room ID or room alias" % (room_identifier,),
)
if not resolved_room_id:
raise SynapseError(
- 400, "Unknown room ID or room alias %s" % room_identifier
+ HTTPStatus.BAD_REQUEST,
+ "Unknown room ID or room alias %s" % room_identifier,
)
return resolved_room_id, remote_room_hosts
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 755ad56637..80f7a2ff58 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Generator, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
@@ -35,6 +35,9 @@ from synapse.logging.context import (
)
from synapse.types import Requester
+if TYPE_CHECKING:
+ import opentracing
+
logger = logging.getLogger(__name__)
_next_request_seq = 0
@@ -66,9 +69,9 @@ class SynapseRequest(Request):
self,
channel: HTTPChannel,
site: "SynapseSite",
- *args,
+ *args: Any,
max_request_body_size: int = 1024,
- **kw,
+ **kw: Any,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
@@ -81,6 +84,10 @@ class SynapseRequest(Request):
# server name, for client requests this is the Requester object.
self._requester: Optional[Union[Requester, str]] = None
+ # An opentracing span for this request. Will be closed when the request is
+ # completely processed.
+ self._opentracing_span: "Optional[opentracing.Span]" = None
+
# we can't yet create the logcontext, as we don't know the method.
self.logcontext: Optional[LoggingContext] = None
@@ -148,6 +155,13 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
+ def set_opentracing_span(self, span: "opentracing.Span") -> None:
+ """attach an opentracing span to this request
+
+ Doing so will cause the span to be closed when we finish processing the request
+ """
+ self._opentracing_span = span
+
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
@@ -286,6 +300,9 @@ class SynapseRequest(Request):
self._processing_finished_time = time.time()
self._is_processing = False
+ if self._opentracing_span:
+ self._opentracing_span.log_kv({"event": "finished processing"})
+
# if we've already sent the response, log it now; otherwise, we wait for the
# response to be sent.
if self.finish_time is not None:
@@ -299,6 +316,8 @@ class SynapseRequest(Request):
"""
self.finish_time = time.time()
Request.finish(self)
+ if self._opentracing_span:
+ self._opentracing_span.log_kv({"event": "response sent"})
if not self._is_processing:
assert self.logcontext is not None
with PreserveLoggingContext(self.logcontext):
@@ -333,6 +352,11 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
logger.info("Connection from client lost before response was sent")
+ if self._opentracing_span:
+ self._opentracing_span.log_kv(
+ {"event": "client connection lost", "reason": str(reason.value)}
+ )
+
if not self._is_processing:
self._finished_processing()
@@ -421,6 +445,10 @@ class SynapseRequest(Request):
usage.evt_db_fetch_count,
)
+ # complete the opentracing span, if any.
+ if self._opentracing_span:
+ self._opentracing_span.finish()
+
try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e:
@@ -557,7 +585,7 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
- def request_factory(channel, queued: bool) -> Request:
+ def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
channel,
self,
|