diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8324632cb6..e5b13593f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
+import urllib.parse
from io import BytesIO
from typing import (
+ TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
-from netaddr import IPAddress
+from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
+ IAddress,
+ IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@@ -53,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
-def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+def check_against_blacklist(
+ ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
"""
+ Compares an IP address to allowed and disallowed IP sets.
+
Args:
- ip_address (netaddr.IPAddress)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ ip_address: The IP address to check
+ ip_whitelist: Allowed IP addresses.
+ ip_blacklist: Disallowed IP addresses.
+
+ Returns:
+ True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
- def __init__(self, reactor, ip_whitelist, ip_blacklist):
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
"""
Args:
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ reactor: The twisted reactor.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def resolveHostName(self, recv, hostname, portNumber=0):
+ def resolveHostName(
+ self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+ ) -> IResolutionReceiver:
r = recv()
- addresses = []
+ addresses = [] # type: List[IAddress]
- def _callback():
+ def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@@ -161,15 +181,15 @@ class IPBlacklistingResolver:
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
- def resolutionBegan(resolutionInProgress):
+ def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
- def addressResolved(address):
+ def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
- def resolutionComplete():
+ def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
- def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ agent: IAgent,
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ ):
"""
Args:
- agent (twisted.web.client.Agent): The Agent to wrap.
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ agent: The Agent to wrap.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:
def __init__(
self,
- hs,
- treq_args={},
- ip_whitelist=None,
- ip_blacklist=None,
- http_proxy=None,
- https_proxy=None,
+ hs: "HomeServer",
+ treq_args: Dict[str, Any] = {},
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ http_proxy: Optional[bytes] = None,
+ https_proxy: Optional[bytes] = None,
):
"""
Args:
- hs (synapse.server.HomeServer)
- treq_args (dict): Extra keyword arguments to be given to treq.request.
- ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+ hs
+ treq_args: Extra keyword arguments to be given to treq.request.
+ ip_blacklist: The IP addresses that are blacklisted that
we may not request.
- ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+ ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy (bytes): proxy server to use for http connections. host[:port]
- https_proxy (bytes): proxy server to use for https connections. host[:port]
+ http_proxy: proxy server to use for http connections. host[:port]
+ https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -306,7 +336,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
- self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@@ -359,7 +388,7 @@ class SimpleHttpClient:
agent=self.agent,
data=body_producer,
headers=headers,
- **self._extra_treq_args
+ **self._extra_treq_args,
) # type: defer.Deferred
# we use our own timeout mechanism rather than treq's as a workaround
@@ -397,7 +426,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
- args: Mapping[str, Union[str, List[str]]] = {},
+ args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -422,9 +451,7 @@ class SimpleHttpClient:
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
- "utf8"
- )
+ query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -432,7 +459,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@ class SimpleHttpClient:
)
async def get_json(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@@ -516,7 +546,7 @@ class SimpleHttpClient:
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@@ -525,7 +555,7 @@ class SimpleHttpClient:
self,
uri: str,
json_body: Any,
- args: QueryParams = {},
+ args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@@ -546,9 +576,9 @@ class SimpleHttpClient:
ValueError: if the response was not JSON
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@@ -558,7 +588,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@ class SimpleHttpClient:
)
async def get_raw(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@@ -592,13 +625,13 @@ class SimpleHttpClient:
HttpResponseException on a non-2xx HTTP response.
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@@ -641,7 +674,7 @@ class SimpleHttpClient:
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@@ -649,12 +682,13 @@ class SimpleHttpClient:
if (
b"Content-Length" in resp_headers
+ and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
- "Requested file is too large > %r bytes" % (self.max_size,),
+ "Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
@@ -668,7 +702,7 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- _readBodyToFile(response, output_stream, max_size)
+ readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
-
-
class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
+ def __init__(
+ self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
+ ):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
+def readBodyToFile(
+ response: IResponse, stream: BinaryIO, max_size: Optional[int]
+) -> defer.Deferred:
+ """
+ Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+ Args:
+ response: The HTTP response to read from.
+ stream: The file-object to write to.
+ max_size: The maximum file size to allow.
+
+ Returns:
+ A Deferred which resolves to the length of the read body.
+ """
-def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
-def encode_urlencode_args(args):
- return {k: encode_urlencode_arg(v) for k, v in args.items()}
+def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+ """
+ Encodes a map of query arguments to bytes which can be appended to a URL.
+ Args:
+ args: The query arguments, a mapping of string to string or list of strings.
+
+ Returns:
+ The query arguments encoded as bytes.
+ """
+ if args is None:
+ return b""
-def encode_urlencode_arg(arg):
- if isinstance(arg, str):
- return arg.encode("utf-8")
- elif isinstance(arg, list):
- return [encode_urlencode_arg(i) for i in arg]
- else:
- return arg
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, str):
+ vs = [vs]
+ encoded_args[k] = [v.encode("utf8") for v in vs]
+ query_str = urllib.parse.urlencode(encoded_args, True)
-def _print_ex(e):
- if hasattr(e, "reasons") and e.reasons:
- for ex in e.reasons:
- _print_ex(ex)
- else:
- logger.exception(e)
+ return query_str.encode("utf8")
class InsecureInterceptableContextFactory(ssl.ContextFactory):
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 83d6196d4a..e77f9587d0 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-import urllib
-from typing import List
+import urllib.parse
+from typing import List, Optional
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import Agent, HTTPConnectionPool
+from twisted.internet.interfaces import (
+ IProtocolFactory,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
+from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IAgentEndpointFactory
+from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -44,30 +48,30 @@ class MatrixFederationAgent:
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
- reactor (IReactor): twisted reactor to use for underlying requests
+ reactor: twisted reactor to use for underlying requests
- tls_client_options_factory (FederationPolicyForHTTPS|None):
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- user_agent (bytes):
+ user_agent:
The user agent header to use for federation requests.
- _srv_resolver (SrvResolver|None):
- SRVResolver impl to use for looking up SRV records. None to use a default
- implementation.
+ _srv_resolver:
+ SrvResolver implementation to use for looking up SRV records. None
+ to use a default implementation.
- _well_known_resolver (WellKnownResolver|None):
+ _well_known_resolver:
WellKnownResolver to use to perform well-known lookups. None to use a
default implementation.
"""
def __init__(
self,
- reactor,
- tls_client_options_factory,
- user_agent,
- _srv_resolver=None,
- _well_known_resolver=None,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ user_agent: bytes,
+ _srv_resolver: Optional[SrvResolver] = None,
+ _well_known_resolver: Optional[WellKnownResolver] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -99,15 +103,20 @@ class MatrixFederationAgent:
self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
"""
Args:
- method (bytes): HTTP method: GET/POST/etc
- uri (bytes): Absolute URI to be retrieved
- headers (twisted.web.http_headers.Headers|None):
- HTTP headers to send with the request, or None to
- send no extra headers.
- bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ method: HTTP method: GET/POST/etc
+ uri: Absolute URI to be retrieved
+ headers:
+ HTTP headers to send with the request, or None to send no extra headers.
+ bodyProducer:
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
@@ -123,6 +132,9 @@ class MatrixFederationAgent:
# explicit port.
parsed_uri = urllib.parse.urlparse(uri)
+ # There must be a valid hostname.
+ assert parsed_uri.hostname
+
# If this is a matrix:// URI check if the server has delegated matrix
# traffic using well-known delegation.
#
@@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: Optional[SrvResolver],
+ ):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
@@ -203,15 +220,20 @@ class MatrixHostnameEndpoint:
resolution (i.e. via SRV). Does not check for well-known delegation.
Args:
- reactor (IReactor)
- tls_client_options_factory (ClientTLSOptionsFactory|None):
+ reactor: twisted reactor to use for underlying requests
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- srv_resolver (SrvResolver): The SRV resolver to use
- parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
- to connect to.
+ srv_resolver: The SRV resolver to use
+ parsed_uri: The parsed URI that we're wanting to connect to.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: SrvResolver,
+ parsed_uri: URI,
+ ):
self._reactor = reactor
self._parsed_uri = parsed_uri
@@ -231,13 +253,13 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
- def connect(self, protocol_factory):
+ def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface
"""
return run_in_background(self._do_connect, protocol_factory)
- async def _do_connect(self, protocol_factory):
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
first_exception = None
server_list = await self._resolve_server()
@@ -303,20 +325,20 @@ class MatrixHostnameEndpoint:
return [Server(host, 8448)]
-def _is_ip_literal(host):
+def _is_ip_literal(host: bytes) -> bool:
"""Test if the given host name is either an IPv4 or IPv6 literal.
Args:
- host (bytes)
+ host: The host name to check
Returns:
- bool
+ True if the hostname is an IP address literal.
"""
- host = host.decode("ascii")
+ host_str = host.decode("ascii")
try:
- IPAddress(host)
+ IPAddress(host_str)
return True
except AddrFormatError:
return False
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index a306faa267..5e08ef1664 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import random
import time
@@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@@ -81,11 +81,11 @@ class WellKnownResolver:
def __init__(
self,
- reactor,
- agent,
- user_agent,
- well_known_cache=None,
- had_well_known_cache=None,
+ reactor: IReactorTime,
+ agent: IAgent,
+ user_agent: bytes,
+ well_known_cache: Optional[TTLCache] = None,
+ had_well_known_cache: Optional[TTLCache] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -127,7 +127,7 @@ class WellKnownResolver:
with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known(
server_name
- ) # type: Tuple[Optional[bytes], float]
+ ) # type: Optional[bytes], float
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
@@ -172,7 +172,7 @@ class WellKnownResolver:
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
# We do this in two steps to differentiate between possibly transient
- # errors (e.g. can't connect to host, 503 response) and more permenant
+ # errors (e.g. can't connect to host, 503 response) and more permanent
# errors (such as getting a 404 response).
response, body = await self._make_well_known_request(
server_name, retry=had_valid_well_known
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c23a4d7c0c..4e27f93b7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,8 +17,9 @@ import cgi
import logging
import random
import sys
-import urllib
+import urllib.parse
from io import BytesIO
+from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@@ -27,25 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
-from twisted.internet import defer, protocol
+from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
-from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
- Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
- SynapseError,
)
from synapse.http import QuieterFileBodyProducer
-from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ IPBlacklistingResolver,
+ encode_query_args,
+ readBodyToFile,
+)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@@ -54,6 +57,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -76,47 +80,44 @@ MAXINT = sys.maxsize
_next_id = 1
+QueryArgs = Dict[str, Union[str, List[str]]]
+
+
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
- method = attr.ib()
+ method = attr.ib(type=str)
"""HTTP method
- :type: str
"""
- path = attr.ib()
+ path = attr.ib(type=str)
"""HTTP path
- :type: str
"""
- destination = attr.ib()
+ destination = attr.ib(type=str)
"""The remote server to send the HTTP request to.
- :type: str"""
+ """
- json = attr.ib(default=None)
+ json = attr.ib(default=None, type=Optional[JsonDict])
"""JSON to send in the body.
- :type: dict|None
"""
- json_callback = attr.ib(default=None)
+ json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
"""A callback to generate the JSON.
- :type: func|None
"""
- query = attr.ib(default=None)
+ query = attr.ib(default=None, type=Optional[dict])
"""Query arguments.
- :type: dict|None
"""
- txn_id = attr.ib(default=None)
+ txn_id = attr.ib(default=None, type=Optional[str])
"""Unique ID for this request (for logging)
- :type: str|None
"""
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
global _next_id
txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
@@ -136,7 +137,7 @@ class MatrixFederationRequest:
)
object.__setattr__(self, "uri", uri)
- def get_json(self):
+ def get_json(self) -> Optional[JsonDict]:
if self.json_callback:
return self.json_callback()
return self.json
@@ -148,7 +149,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
-):
+) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@@ -160,7 +161,7 @@ async def _handle_json_response(
start_ms: Timestamp when request was made
Returns:
- dict: parsed JSON response
+ The parsed JSON response
"""
try:
check_content_type_is_json(response.headers)
@@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
- self.reactor,
- ip_blacklist=hs.config.federation_ip_range_blacklist,
+ self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -266,27 +265,29 @@ class MatrixFederationHttpClient:
self._cooperator = Cooperator(scheduler=schedule)
async def _send_request_with_optional_trailing_slash(
- self, request, try_trailing_slash_on_400=False, **send_request_args
- ):
+ self,
+ request: MatrixFederationRequest,
+ try_trailing_slash_on_400: bool = False,
+ **send_request_args
+ ) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
due to #3622.
Args:
- request (MatrixFederationRequest): details of request to be sent
- try_trailing_slash_on_400 (bool): Whether on receiving a 400
+ request: details of request to be sent
+ try_trailing_slash_on_400: Whether on receiving a 400
'M_UNRECOGNIZED' from the server to retry the request with a
trailing slash appended to the request path.
- send_request_args (Dict): A dictionary of arguments to pass to
- `_send_request()`.
+ send_request_args: A dictionary of arguments to pass to `_send_request()`.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
Returns:
- Dict: Parsed JSON response body.
+ Parsed JSON response body.
"""
try:
response = await self._send_request(request, **send_request_args)
@@ -313,24 +314,26 @@ class MatrixFederationHttpClient:
async def _send_request(
self,
- request,
- retry_on_dns_fail=True,
- timeout=None,
- long_retries=False,
- ignore_backoff=False,
- backoff_on_404=False,
- ):
+ request: MatrixFederationRequest,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ long_retries: bool = False,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ ) -> IResponse:
"""
Sends a request to the given server.
Args:
- request (MatrixFederationRequest): details of request to be sent
+ request: details of request to be sent
+
+ retry_on_dns_fail: true if the request should be retied on DNS failures
- timeout (int|None): number of milliseconds to wait for the response headers
+ timeout: number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
60s by default.
- long_retries (bool): whether to use the long retry algorithm.
+ long_retries: whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
@@ -346,14 +349,13 @@ class MatrixFederationHttpClient:
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): Back off if we get a 404
+ backoff_on_404: Back off if we get a 404
Returns:
- twisted.web.client.Response: resolves with the HTTP
- response object on success.
+ Resolves with the HTTP response object on success.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
@@ -404,7 +406,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
- headers_dict = {}
+ headers_dict = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -435,7 +437,7 @@ class MatrixFederationHttpClient:
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
- )
+ ) # type: Optional[IBodyProducer]
else:
producer = None
auth_headers = self.build_auth_headers(
@@ -524,14 +526,16 @@ class MatrixFederationHttpClient:
)
body = None
- e = HttpResponseException(response.code, response_phrase, body)
+ exc = HttpResponseException(
+ response.code, response_phrase, body
+ )
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
- raise RequestSendFailed(e, can_retry=True) from e
+ raise RequestSendFailed(exc, can_retry=True) from exc
else:
- raise e
+ raise exc
break
except RequestSendFailed as e:
@@ -582,22 +586,27 @@ class MatrixFederationHttpClient:
return response
def build_auth_headers(
- self, destination, method, url_bytes, content=None, destination_is=None
- ):
+ self,
+ destination: Optional[bytes],
+ method: bytes,
+ url_bytes: bytes,
+ content: Optional[JsonDict] = None,
+ destination_is: Optional[bytes] = None,
+ ) -> List[bytes]:
"""
Builds the Authorization headers for a federation request
Args:
- destination (bytes|None): The desination homeserver of the request.
+ destination: The destination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
- method (bytes): The HTTP method of the request
- url_bytes (bytes): The URI path of the request
- content (object): The body of the request
- destination_is (bytes): As 'destination', but if the destination is an
+ method: The HTTP method of the request
+ url_bytes: The URI path of the request
+ content: The body of the request
+ destination_is: As 'destination', but if the destination is an
identity server
Returns:
- list[bytes]: a list of headers to be added as "Authorization:" headers
+ A list of headers to be added as "Authorization:" headers
"""
request = {
"method": method.decode("ascii"),
@@ -629,33 +638,32 @@ class MatrixFederationHttpClient:
async def put_json(
self,
- destination,
- path,
- args={},
- data={},
- json_data_callback=None,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- backoff_on_404=False,
- try_trailing_slash_on_400=False,
- ):
- """ Sends the specifed json data using PUT
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
+ """ Sends the specified json data using PUT
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
- args (dict): query params
- data (dict): A dict containing the data that will be used as
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
+ args: query params
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- json_data_callback (callable): A callable returning the dict to
+ json_data_callback: A callable returning the dict to
use as the request body.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -663,19 +671,19 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): True if we should count a 404 response as
+ backoff_on_404: True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests).
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -721,29 +729,28 @@ class MatrixFederationHttpClient:
async def post_json(
self,
- destination,
- path,
- data={},
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
- """ Sends the specifed json data using POST
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
+ """ Sends the specified json data using POST
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- data (dict): A dict containing the data that will be used as
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -751,10 +758,10 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -795,26 +802,25 @@ class MatrixFederationHttpClient:
async def get_json(
self,
- destination,
- path,
- args=None,
- retry_on_dns_fail=True,
- timeout=None,
- ignore_backoff=False,
- try_trailing_slash_on_400=False,
- ):
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
""" GETs some json from the given host homeserver and path
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- args (dict|None): A dictionary used to create query strings, defaults to
+ args: A dictionary used to create query strings, defaults to
None.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -822,14 +828,14 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -870,24 +876,23 @@ class MatrixFederationHttpClient:
async def delete_json(
self,
- destination,
- path,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
+ destination: str,
+ path: str,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -895,12 +900,12 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -938,25 +943,25 @@ class MatrixFederationHttpClient:
async def get_file(
self,
- destination,
- path,
+ destination: str,
+ path: str,
output_stream,
- args={},
- retry_on_dns_fail=True,
- max_size=None,
- ignore_backoff=False,
- ):
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ max_size: Optional[int] = None,
+ ignore_backoff: bool = False,
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver
Args:
- destination (str): The remote server to send the HTTP request to.
- path (str): The HTTP path to GET.
- output_stream (file): File to write the response body to.
- args (dict): Optional dictionary used to create the query string.
- ignore_backoff (bool): true to ignore the historical backoff data
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path to GET.
+ output_stream: File to write the response body to.
+ args: Optional dictionary used to create the query string.
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
- tuple[int, dict]: Resolves with an (int,dict) tuple of
+ Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Raises:
@@ -980,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = _readBodyToFile(response, output_stream, max_size)
+ d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient:
return (length, headers)
-class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
- self.stream = stream
- self.deferred = deferred
- self.length = 0
- self.max_size = max_size
-
- def dataReceived(self, data):
- self.stream.write(data)
- self.length += len(data)
- if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
- self.deferred = defer.Deferred()
- self.transport.loseConnection()
-
- def connectionLost(self, reason):
- if reason.check(ResponseDone):
- self.deferred.callback(self.length)
- else:
- self.deferred.errback(reason)
-
-
-def _readBodyToFile(response, stream, max_size):
- d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
- return d
-
-
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e):
return repr(e)
-def check_content_type_is_json(headers):
+def check_content_type_is_json(headers: Headers) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
is application/json.
Args:
- headers (twisted.web.http_headers.Headers): headers to check
+ headers: headers to check
Raises:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
@@ -1063,27 +1034,18 @@ def check_content_type_is_json(headers):
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
- raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False)
+ raise RequestSendFailed(
+ RuntimeError("No Content-Type header received from remote server"),
+ can_retry=False,
+ )
c_type = c_type[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(
- RuntimeError("Content-Type not application/json: was '%s'" % c_type),
+ RuntimeError(
+ "Remote server sent Content-Type header of '%s', not 'application/json'"
+ % c_type,
+ ),
can_retry=False,
)
-
-
-def encode_query_args(args):
- if args is None:
- return b""
-
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, str):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.parse.urlencode(encoded_args, True)
-
- return query_bytes.encode("utf8")
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index cd94e789e8..7c5defec82 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -109,7 +109,7 @@ in_flight_requests_db_sched_duration = Counter(
# The set of all in flight requests, set[RequestMetrics]
_in_flight_requests = set()
-# Protects the _in_flight_requests set from concurrent accesss
+# Protects the _in_flight_requests set from concurrent access
_in_flight_requests_lock = threading.Lock()
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 996a31a9ec..6a4e429a6c 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
-from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File, NoRangeStaticProducer
from twisted.web.util import redirectTo
-import synapse.events
-import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -96,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
- request,
- error_code,
- error_dict,
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
+ request, error_code, error_dict, send_cors=True,
)
@@ -182,7 +176,7 @@ class HttpServer:
""" Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
- If the regex contains groups these gets passed to the calback via
+ If the regex contains groups these gets passed to the callback via
an unpacked tuple.
Args:
@@ -241,7 +235,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
async def _async_render(self, request: Request):
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
- no appropriate method exists. Can be overriden in sub classes for
+ no appropriate method exists. Can be overridden in sub classes for
different routing.
"""
# Treat HEAD requests as GET requests.
@@ -257,7 +251,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
@@ -292,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
code,
response_object,
send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
@@ -386,7 +379,7 @@ class JsonResource(DirectServeJsonResource):
async def _async_render(self, request):
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
- # Make sure we have an appopriate name for this handler in prometheus
+ # Make sure we have an appropriate name for this handler in prometheus
# (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname
@@ -406,7 +399,7 @@ class JsonResource(DirectServeJsonResource):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
@@ -589,7 +582,6 @@ def respond_with_json(
code: int,
json_object: Any,
send_cors: bool = False,
- pretty_print: bool = False,
canonical_json: bool = True,
):
"""Sends encoded JSON in response to the given request.
@@ -600,8 +592,6 @@ def respond_with_json(
json_object: The object to serialize to JSON.
send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
- pretty_print: Whether to include indentation and line-breaks in the
- resulting JSON bytes.
canonical_json: Whether to use the canonicaljson algorithm when encoding
the JSON bytes.
@@ -617,13 +607,10 @@ def respond_with_json(
)
return None
- if pretty_print:
- encoder = iterencode_pretty_printed_json
+ if canonical_json:
+ encoder = iterencode_canonical_json
else:
- if canonical_json or synapse.events.USE_FROZEN_DICTS:
- encoder = iterencode_canonical_json
- else:
- encoder = _encode_json_bytes
+ encoder = _encode_json_bytes
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@@ -651,6 +638,11 @@ def respond_with_json_bytes(
Returns:
twisted.web.server.NOT_DONE_YET if the request is still active.
"""
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@@ -682,7 +674,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
)
@@ -756,11 +748,3 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
-
-
-def _request_user_agent_is_curl(request: Request) -> bool:
- user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
- for user_agent in user_agents:
- if b"curl" in user_agent:
- return True
- return False
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fd90ba7828..b361b7cbaf 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -272,7 +272,6 @@ class RestServlet:
on_PUT
on_POST
on_DELETE
- on_OPTIONS
Automatically handles turning CodeMessageExceptions thrown by these methods
into the appropriate HTTP response.
@@ -283,7 +282,7 @@ class RestServlet:
if hasattr(self, "PATTERNS"):
patterns = self.PATTERNS
- for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
+ for method in ("GET", "PUT", "POST", "DELETE"):
if hasattr(self, "on_%s" % (method,)):
servlet_classname = self.__class__.__name__
method_handler = getattr(self, "on_%s" % (method,))
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 6e79b47828..5f0581dc3f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Optional
+from typing import Optional, Union
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.types import Requester
logger = logging.getLogger(__name__)
@@ -54,9 +55,12 @@ class SynapseRequest(Request):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self._channel = channel # this is used by the tests
- self.authenticated_entity = None
self.start_time = 0.0
+ # The requester, if authenticated. For federation requests this is the
+ # server name, for client requests this is the Requester object.
+ self.requester = None # type: Optional[Union[Requester, str]]
+
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext]
@@ -109,8 +113,14 @@ class SynapseRequest(Request):
method = self.method.decode("ascii")
return method
- def get_user_agent(self):
- return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
+ def get_user_agent(self, default: str) -> str:
+ """Return the last User-Agent header, or the given default.
+ """
+ user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
+ if user_agent is None:
+ return default
+
+ return user_agent.decode("ascii", "replace")
def render(self, resrc):
# this is called once a Resource has been found to serve the request; in our
@@ -161,7 +171,9 @@ class SynapseRequest(Request):
yield
except Exception:
# this should already have been caught, and sent back to the client as a 500.
- logger.exception("Asynchronous messge handler raised an uncaught exception")
+ logger.exception(
+ "Asynchronous message handler raised an uncaught exception"
+ )
finally:
# the request handler has finished its work and either sent the whole response
# back, or handed over responsibility to a Producer.
@@ -263,22 +275,30 @@ class SynapseRequest(Request):
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
- # need to decode as it could be raw utf-8 bytes
- # from a IDN servname in an auth header
- authenticated_entity = self.authenticated_entity
- if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
- authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+ # Convert the requester into a string that we can log
+ authenticated_entity = None
+ if isinstance(self.requester, str):
+ authenticated_entity = self.requester
+ elif isinstance(self.requester, Requester):
+ authenticated_entity = self.requester.authenticated_entity
+
+ # If this is a request where the target user doesn't match the user who
+ # authenticated (e.g. and admin is puppetting a user) then we log both.
+ if self.requester.user.to_string() != authenticated_entity:
+ authenticated_entity = "{},{}".format(
+ authenticated_entity, self.requester.user.to_string(),
+ )
+ elif self.requester is not None:
+ # This shouldn't happen, but we log it so we don't lose information
+ # and can see that we're doing something wrong.
+ authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about
# the charset problem.
# c.f. https://github.com/matrix-org/synapse/issues/3471
- user_agent = self.get_user_agent()
- if user_agent is not None:
- user_agent = user_agent.decode("utf-8", "replace")
- else:
- user_agent = "-"
+ user_agent = self.get_user_agent("-")
code = str(self.code)
if not self.finished:
|