diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 4bc3cb53f0..c658862fe6 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
def get_request_user_agent(request: IRequest, default: str = "") -> str:
- """Return the last User-Agent header, or the given default.
- """
+ """Return the last User-Agent header, or the given default."""
# There could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37ccf5ab98..a910548f1e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -56,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -289,8 +289,7 @@ class SimpleHttpClient:
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
- http_proxy: Optional[bytes] = None,
- https_proxy: Optional[bytes] = None,
+ use_proxy: bool = False,
):
"""
Args:
@@ -300,8 +299,8 @@ class SimpleHttpClient:
we may not request.
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy: proxy server to use for http connections. host[:port]
- https_proxy: proxy server to use for https connections. host[:port]
+            use_proxy: Whether to discover proxy settings from conventional
+                environment variables and use them.
"""
self.hs = hs
@@ -345,8 +344,7 @@ class SimpleHttpClient:
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
- http_proxy=http_proxy,
- https_proxy=https_proxy,
+ use_proxy=use_proxy,
)
if self._ip_blacklist:
@@ -398,7 +396,8 @@ class SimpleHttpClient:
body_producer = None
if data is not None:
body_producer = QuieterFileBodyProducer(
- BytesIO(data), cooperator=self._cooperator,
+ BytesIO(data),
+ cooperator=self._cooperator,
)
request_deferred = treq.request(
@@ -407,13 +406,18 @@ class SimpleHttpClient:
agent=self.agent,
data=body_producer,
headers=headers,
+ # Avoid buffering the body in treq since we do not reuse
+ # response bodies.
+ unbuffered=True,
**self._extra_treq_args,
) # type: defer.Deferred
# we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
- request_deferred, 60, self.hs.get_reactor(),
+ request_deferred,
+ 60,
+ self.hs.get_reactor(),
)
# turn timeouts into RequestTimedOutErrors
@@ -699,18 +703,6 @@ class SimpleHttpClient:
resp_headers = dict(response.headers.getAllRawHeaders())
- if (
- b"Content-Length" in resp_headers
- and max_size
- and int(resp_headers[b"Content-Length"][0]) > max_size
- ):
- logger.warning("Requested URL is too large > %r bytes" % (max_size,))
- raise SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (max_size,),
- Codes.TOO_LARGE,
- )
-
if response.code > 299:
logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
@@ -777,7 +769,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
# in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
- self.transport.loseConnection()
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ self.transport.abortConnection()
def connectionLost(self, reason: Failure) -> None:
# If the maximum size was already exceeded, there's nothing to do.
@@ -811,6 +805,11 @@ def read_body_with_max_size(
Returns:
A Deferred which resolves to the length of the read body.
"""
+ # If the Content-Length header gives a size larger than the maximum allowed
+ # size, do not bother downloading the body.
+ if max_size is not None and response.length != UNKNOWN_LENGTH:
+ if response.length > max_size:
+ return defer.fail(BodyExceededMaxSize())
d = defer.Deferred()
response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
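For reviewers, a minimal usage sketch (not part of the patch) of the behaviour changed above: read_body_with_max_size now fails immediately with BodyExceededMaxSize when the advertised Content-Length exceeds max_size, and the streaming protocol aborts (rather than half-closes) the connection once the limit is crossed. The helper name and argument order match the hunks above; the download_to_buffer wrapper is hypothetical.

from io import BytesIO

from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size


async def download_to_buffer(response, max_size: int) -> bytes:
    """Hypothetical caller: buffer at most `max_size` bytes of `response`."""
    buffer = BytesIO()
    try:
        # Fails fast with BodyExceededMaxSize if the advertised Content-Length
        # already exceeds max_size; otherwise streams the body and errbacks as
        # soon as more than max_size bytes have been received.
        await read_body_with_max_size(response, buffer, max_size)
    except BodyExceededMaxSize:
        # The transport has been aborted (abortConnection above), so there is
        # no remaining body to drain; just surface the error to the caller.
        raise
    return buffer.getvalue()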
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -19,9 +19,10 @@ from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.internet.protocol import connectionDone
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
from twisted.web import http
+from twisted.web.http_headers import Headers
logger = logging.getLogger(__name__)
@@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint:
Args:
reactor: the Twisted reactor to use for the connection
- proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
- proxy
- host (bytes): hostname that we want to CONNECT to
- port (int): port that we want to connect to
+ proxy_endpoint: the endpoint to use to connect to the proxy
+ host: hostname that we want to CONNECT to
+ port: port that we want to connect to
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, reactor, proxy_endpoint, host, port):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ proxy_endpoint: IStreamClientEndpoint,
+ host: bytes,
+ port: int,
+ headers: Headers,
+ ):
self._reactor = reactor
self._proxy_endpoint = proxy_endpoint
self._host = host
self._port = port
+ self._headers = headers
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
- def connect(self, protocolFactory):
- f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ def connect(self, protocolFactory: ClientFactory):
+ f = HTTPProxiedClientFactory(
+ self._host, self._port, protocolFactory, self._headers
+ )
d = self._proxy_endpoint.connect(f)
# once the tcp socket connects successfully, we need to wait for the
# CONNECT to complete.
@@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
HTTP Protocol object and run the rest of the connection.
Args:
- dst_host (bytes): hostname that we want to CONNECT to
- dst_port (int): port that we want to connect to
- wrapped_factory (protocol.ClientFactory): The original Factory
+ dst_host: hostname that we want to CONNECT to
+ dst_port: port that we want to connect to
+ wrapped_factory: The original Factory
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, dst_host, dst_port, wrapped_factory):
+ def __init__(
+ self,
+ dst_host: bytes,
+ dst_port: int,
+ wrapped_factory: ClientFactory,
+ headers: Headers,
+ ):
self.dst_host = dst_host
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
+ self.headers = headers
self.on_connection = defer.Deferred()
def startedConnecting(self, connector):
@@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
return HTTPConnectProtocol(
- self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ self.dst_host,
+ self.dst_port,
+ wrapped_protocol,
+ self.on_connection,
+ self.headers,
)
def clientConnectionFailed(self, connector, reason):
@@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol):
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
Args:
- host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
to put in the CONNECT request
- port (int): The original HTTP(s) port to put in the CONNECT request
+ port: The original HTTP(s) port to put in the CONNECT request
- wrapped_protocol (interfaces.IProtocol): the original protocol (probably
- HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+ wrapped_protocol: the original protocol (probably HTTPChannel or
+ TLSMemoryBIOProtocol, but could be anything really)
- connected_deferred (Deferred): a Deferred which will be callbacked with
+ connected_deferred: a Deferred which will be callbacked with
wrapped_protocol when the CONNECT completes
+
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ def __init__(
+ self,
+ host: bytes,
+ port: int,
+ wrapped_protocol: Protocol,
+ connected_deferred: defer.Deferred,
+ headers: Headers,
+ ):
self.host = host
self.port = port
self.wrapped_protocol = wrapped_protocol
self.connected_deferred = connected_deferred
- self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.headers = headers
+
+ self.http_setup_client = HTTPConnectSetupClient(
+ self.host, self.port, self.headers
+ )
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
def connectionMade(self):
@@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol):
if buf:
self.wrapped_protocol.dataReceived(buf)
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes):
# if we've set up the HTTP protocol, we can send the data there
if self.wrapped_protocol.connected:
return self.wrapped_protocol.dataReceived(data)
@@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient):
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
Args:
- host (bytes): The hostname to send in the CONNECT message
- port (int): The port to send in the CONNECT message
+ host: The hostname to send in the CONNECT message
+ port: The port to send in the CONNECT message
+ headers: Extra headers to send with the CONNECT message
"""
- def __init__(self, host, port):
+ def __init__(self, host: bytes, port: int, headers: Headers):
self.host = host
self.port = port
+ self.headers = headers
self.on_connected = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+
+ # Send any additional specified headers
+ for name, values in self.headers.getAllRawHeaders():
+ for value in values:
+ self.sendHeader(name, value)
+
self.endHeaders()
- def handleStatus(self, version, status, message):
+ def handleStatus(self, version: bytes, status: bytes, message: bytes):
logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200":
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
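A hedged wiring sketch (not part of the patch) showing how the new `headers` argument threads through to the CONNECT request; the proxy address, target host and credentials below are placeholders.

from twisted.internet import reactor
from twisted.internet.endpoints import HostnameEndpoint
from twisted.web.http_headers import Headers

from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint

# Endpoint for the TCP connection to the proxy itself (placeholder address).
proxy_endpoint = HostnameEndpoint(reactor, b"proxy.example.com", 8888)

# Any headers added here are written into the CONNECT request by
# HTTPConnectSetupClient.connectionMade() above.
connect_headers = Headers()
connect_headers.addRawHeader(b"Proxy-Authorization", b"Basic dXNlcjpwYXNz")

endpoint = HTTPConnectProxyEndpoint(
    reactor,
    proxy_endpoint,
    b"matrix.org",  # host to CONNECT to
    443,  # port to CONNECT to
    connect_headers,
)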
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 4c06a117d3..2e83fa6773 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -195,8 +195,7 @@ class MatrixFederationAgent:
@implementer(IAgentEndpointFactory)
class MatrixHostnameEndpointFactory:
- """Factory for MatrixHostnameEndpoint for parsing to an Agent.
- """
+ """Factory for MatrixHostnameEndpoint for parsing to an Agent."""
def __init__(
self,
@@ -261,8 +260,7 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
- """Implements IStreamClientEndpoint interface
- """
+ """Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory)
@@ -323,12 +321,19 @@ class MatrixHostnameEndpoint:
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]
+ logger.debug("Looking up SRV record for %s", host.decode(errors="replace"))
server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list:
+ logger.debug(
+ "Got %s from SRV lookup for %s",
+ ", ".join(map(str, server_list)),
+ host.decode(errors="replace"),
+ )
return server_list
# No SRV records, so we fallback to host and 8448
+ logger.debug("No SRV records for %s", host.decode(errors="replace"))
return [Server(host, 8448)]
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index b3b6dbcab0..4def7d7633 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -81,8 +81,7 @@ class WellKnownLookupResult:
class WellKnownResolver:
- """Handles well-known lookups for matrix servers.
- """
+ """Handles well-known lookups for matrix servers."""
def __init__(
self,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 19293bf673..cde42e9f5e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -254,7 +254,8 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
+ self.agent,
+ ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -652,7 +653,7 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
- """ Sends the specified json data using PUT
+ """Sends the specified json data using PUT
Args:
destination: The remote server to send the HTTP request to.
@@ -740,7 +741,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
) -> Union[JsonDict, list]:
- """ Sends the specified json data using POST
+ """Sends the specified json data using POST
Args:
destination: The remote server to send the HTTP request to.
@@ -799,7 +800,11 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms,
+ self.reactor,
+ _sec_timeout,
+ request,
+ response,
+ start_ms,
)
return body
@@ -813,7 +818,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
- """ GETs some json from the given host homeserver and path
+ """GETs some json from the given host homeserver and path
Args:
destination: The remote server to send the HTTP request to.
@@ -994,7 +999,10 @@ class MatrixFederationHttpClient:
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (max_size,)
logger.warning(
- "{%s} [%s] %s", request.txn_id, request.destination, msg,
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
)
raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e:
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b730d2c634..ee65a6668b 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,9 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
import re
+from typing import Optional, Tuple
+from urllib.request import getproxies_environment, proxy_bypass_environment
+import attr
from zope.interface import implementer
from twisted.internet import defer
@@ -22,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.python.failure import Failure
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
from twisted.web.error import SchemeNotSupported
+from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@@ -31,6 +36,22 @@ logger = logging.getLogger(__name__)
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+@attr.s
+class ProxyCredentials:
+ username_password = attr.ib(type=bytes)
+
+ def as_proxy_authorization_value(self) -> bytes:
+ """
+        Return the value for a Proxy-Authorization header (e.g. 'Basic abdef==').
+
+ Returns:
+            The authentication string, base64-encoded and prefixed with the
+            authorization type, for use as a Proxy-Authorization header value.
+ """
+ # Encode as base64 and prepend the authorization type
+ return b"Basic " + base64.encodebytes(self.username_password)
+
+
@implementer(IAgent)
class ProxyAgent(_AgentBase):
"""An Agent implementation which will use an HTTP proxy if one was requested
@@ -58,6 +79,9 @@ class ProxyAgent(_AgentBase):
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
non-persistent pool instance will be created.
+
+        use_proxy (bool): Whether to discover proxy settings from conventional
+            environment variables and use them.
"""
def __init__(
@@ -68,8 +92,7 @@ class ProxyAgent(_AgentBase):
connectTimeout=None,
bindAddress=None,
pool=None,
- http_proxy=None,
- https_proxy=None,
+ use_proxy=False,
):
_AgentBase.__init__(self, reactor, pool)
@@ -84,6 +107,18 @@ class ProxyAgent(_AgentBase):
if bindAddress is not None:
self._endpoint_kwargs["bindAddress"] = bindAddress
+ http_proxy = None
+ https_proxy = None
+ no_proxy = None
+ if use_proxy:
+ proxies = getproxies_environment()
+ http_proxy = proxies["http"].encode() if "http" in proxies else None
+ https_proxy = proxies["https"].encode() if "https" in proxies else None
+ no_proxy = proxies["no"] if "no" in proxies else None
+
+ # Parse credentials from https proxy connection string if present
+ self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
+
self.http_proxy_endpoint = _http_proxy_endpoint(
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
@@ -92,6 +127,8 @@ class ProxyAgent(_AgentBase):
https_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
+ self.no_proxy = no_proxy
+
self._policy_for_https = contextFactory
self._reactor = reactor
@@ -139,18 +176,43 @@ class ProxyAgent(_AgentBase):
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
request_path = parsed_uri.originForm
- if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ should_skip_proxy = False
+ if self.no_proxy is not None:
+ should_skip_proxy = proxy_bypass_environment(
+ parsed_uri.host.decode(), proxies={"no": self.no_proxy},
+ )
+
+ if (
+ parsed_uri.scheme == b"http"
+ and self.http_proxy_endpoint
+ and not should_skip_proxy
+ ):
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
endpoint = self.http_proxy_endpoint
request_path = uri
- elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ elif (
+ parsed_uri.scheme == b"https"
+ and self.https_proxy_endpoint
+ and not should_skip_proxy
+ ):
+ connect_headers = Headers()
+
+ # Determine whether we need to set Proxy-Authorization headers
+ if self.https_proxy_creds:
+ # Set a Proxy-Authorization header
+ connect_headers.addRawHeader(
+ b"Proxy-Authorization",
+ self.https_proxy_creds.as_proxy_authorization_value(),
+ )
+
endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor,
self.https_proxy_endpoint,
parsed_uri.host,
parsed_uri.port,
+ headers=connect_headers,
)
else:
# not using a proxy
@@ -179,12 +241,16 @@ class ProxyAgent(_AgentBase):
)
-def _http_proxy_endpoint(proxy, reactor, **kwargs):
+def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
"""Parses an http proxy setting and returns an endpoint for the proxy
Args:
- proxy (bytes|None): the proxy setting
+ proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
+        Note that compared to other apps, this function currently lacks support
+        for specifying a protocol scheme (i.e. protocol://...).
+
reactor: reactor to be used to connect to the proxy
+
kwargs: other args to be passed to HostnameEndpoint
Returns:
@@ -194,16 +260,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs):
if proxy is None:
return None
- # currently we only support hostname:port. Some apps also support
- # protocol://<host>[:port], which allows a way of requiring a TLS connection to the
- # proxy.
-
+ # Parse the connection string
host, port = parse_host_port(proxy, default_port=1080)
return HostnameEndpoint(reactor, host, port, **kwargs)
-def parse_host_port(hostport, default_port=None):
- # could have sworn we had one of these somewhere else...
+def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
+ """
+    Parses the username and password from a proxy declaration, e.g.
+ username:password@hostname:port.
+
+ Args:
+ proxy: The proxy connection string.
+
+    Returns:
+        An instance of ProxyCredentials and the proxy connection string with any
+        credentials stripped, i.e. u:p@host:port -> host:port. If no credentials
+        were found, None is returned in place of the ProxyCredentials instance.
+ """
+ if proxy and b"@" in proxy:
+ # We use rsplit here as the password could contain an @ character
+ credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
+ return ProxyCredentials(credentials), proxy_without_credentials
+
+ return None, proxy
+
+
+def parse_host_port(hostport: bytes, default_port: Optional[int] = None) -> Tuple[bytes, int]:
+ """
+ Parse the hostname and port from a proxy connection byte string.
+
+ Args:
+ hostport: The proxy connection string. Must be in the form 'host[:port]'.
+ default_port: The default port to return if one is not found in `hostport`.
+
+ Returns:
+ A tuple containing the hostname and port. Uses `default_port` if one was not found.
+ """
if b":" in hostport:
host, port = hostport.rsplit(b":", 1)
try:
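To make the credential handling above concrete, here is a rough sketch (not part of the patch) of how an authenticated HTTPS proxy flows through the new helpers; the proxy string is a placeholder and, per the docstring above, carries no protocol scheme.

from twisted.internet import reactor

from synapse.http.proxyagent import (
    ProxyAgent,
    parse_host_port,
    parse_username_password,
)

# e.g. HTTPS_PROXY=squid-user:secret@proxy.example.com:8888 in the environment
https_proxy = b"squid-user:secret@proxy.example.com:8888"

# Credentials (if any) are split off before the proxy endpoint is built...
creds, hostport = parse_username_password(https_proxy)
host, port = parse_host_port(hostport, default_port=1080)
# host == b"proxy.example.com", port == 8888

# ...and later re-emitted as a Proxy-Authorization header on the CONNECT
# request: b"Basic " followed by the base64 encoding of b"squid-user:secret".
assert creds is not None
auth_value = creds.as_proxy_authorization_value()

# With use_proxy=True the agent performs this discovery itself from the
# HTTP(S)_PROXY / NO_PROXY environment variables.
agent = ProxyAgent(reactor, use_proxy=True)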
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 7c5defec82..0ec5d941b8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -213,8 +213,7 @@ class RequestMetrics:
self.update_metrics()
def update_metrics(self):
- """Updates the in flight metrics with values from this request.
- """
+ """Updates the in flight metrics with values from this request."""
new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 8249732b27..845db9b78d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -76,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
- """Sends a JSON error response to clients.
- """
+ """Sends a JSON error response to clients."""
if f.check(SynapseError):
error_code = f.value.code
@@ -106,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
- request, error_code, error_dict, send_cors=True,
+ request,
+ error_code,
+ error_dict,
+ send_cors=True,
)
def return_html_error(
- f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+ f: failure.Failure,
+ request: Request,
+ error_template: Union[str, jinja2.Template],
) -> None:
"""Sends an HTML error page corresponding to the given failure.
@@ -189,8 +193,7 @@ ServletCallback = Callable[
class HttpServer(Protocol):
- """ Interface for registering callbacks on a HTTP server
- """
+ """Interface for registering callbacks on a HTTP server"""
def register_paths(
self,
@@ -199,7 +202,7 @@ class HttpServer(Protocol):
callback: ServletCallback,
servlet_classname: str,
) -> None:
- """ Register a callback that gets fired if we receive a http request
+ """Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
If the regex contains groups these gets passed to the callback via
@@ -235,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context
def render(self, request):
- """ This gets called by twisted every time someone sends us a request.
- """
+ """This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@@ -287,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod
def _send_response(
- self, request: SynapseRequest, code: int, response_object: Any,
+ self,
+ request: SynapseRequest,
+ code: int,
+ response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
raise NotImplementedError()
@@ -308,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource):
self.canonical_json = canonical_json
def _send_response(
- self, request: Request, code: int, response_object: Any,
+ self,
+ request: Request,
+ code: int,
+ response_object: Any,
):
- """Implements _AsyncResource._send_response
- """
+ """Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
@@ -322,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource):
)
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
- """Implements _AsyncResource._send_error_response
- """
+ """Implements _AsyncResource._send_error_response"""
return_json_error(f, request)
class JsonResource(DirectServeJsonResource):
- """ This implements the HttpServer interface and provides JSON support for
+ """This implements the HttpServer interface and provides JSON support for
Resources.
Register callbacks via register_paths()
@@ -443,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource):
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
- self, request: SynapseRequest, code: int, response_object: Any,
+ self,
+ request: SynapseRequest,
+ code: int,
+ response_object: Any,
):
- """Implements _AsyncResource._send_response
- """
+ """Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
@@ -454,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource):
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
- """Implements _AsyncResource._send_error_response
- """
+ """Implements _AsyncResource._send_error_response"""
return_html_error(f, request, self.ERROR_TEMPLATE)
@@ -534,7 +547,9 @@ class _ByteProducer:
min_chunk_size = 1024
def __init__(
- self, request: Request, iterator: Iterator[bytes],
+ self,
+ request: Request,
+ iterator: Iterator[bytes],
):
self._request = request
self._iterator = iterator
@@ -654,7 +669,10 @@ def respond_with_json(
def respond_with_json_bytes(
- request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
+ request: Request,
+ code: int,
+ json_bytes: bytes,
+ send_cors: bool = False,
):
"""Sends encoded JSON in response to the given request.
@@ -769,7 +787,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
def finish_request(request: Request):
- """ Finish writing the response to the request.
+ """Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
response was written but doesn't provide a convenient or reliable way to
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index b361b7cbaf..839d58d0d4 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,8 +14,8 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
-
import logging
+from typing import Dict, List, Optional, Union
from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder
@@ -147,16 +147,67 @@ def parse_string(
)
+def parse_list_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: Union[bytes, str],
+ encoding: Optional[str] = "ascii",
+):
+ """Parse and optionally decode a list of values from request query parameters.
+
+ Args:
+ args: A dictionary of query parameters from a request.
+        name: The name of the query parameter to extract values from. If given
+            as a str, it will be encoded as "ascii".
+        encoding: An optional encoding used to decode each parameter value.
+
+ Raises:
+ KeyError: If the given `name` does not exist in `args`.
+ SynapseError: If an argument was not encoded with the specified `encoding`.
+ """
+ if not isinstance(name, bytes):
+ name = name.encode("ascii")
+ args_list = args[name]
+
+ if encoding:
+ # Decode each argument value
+ try:
+ args_list = [value.decode(encoding) for value in args_list]
+ except ValueError:
+ raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+
+ return args_list
+
+
def parse_string_from_args(
- args,
- name,
- default=None,
- required=False,
- allowed_values=None,
- param_type="string",
- encoding="ascii",
+ args: Dict[bytes, List[bytes]],
+ name: Union[bytes, str],
+ default: Optional[str] = None,
+ required: Optional[bool] = False,
+ allowed_values: Optional[List[bytes]] = None,
+ param_type: Optional[str] = "string",
+ encoding: Optional[str] = "ascii",
):
+ """Parse and optionally decode a single value from request query parameters.
+ Args:
+ args: A dictionary of query parameters from a request.
+        name: The name of the query parameter to extract values from. If given
+            as a str, it will be encoded as "ascii".
+ default: A default value to return if the given argument `name` was not found.
+        required: If True, a SynapseError is raised if the given argument `name`
+            is not found in `args`.
+        allowed_values: A list of allowed values. If specified and the parsed
+            value is not in this list, a SynapseError is raised.
+ param_type: The expected type of the query parameter's value.
+        encoding: An optional encoding used to decode each parameter value.
+
+ Returns:
+ The found argument value.
+
+ Raises:
+ SynapseError: If the given name was not found in the request arguments,
+ the argument's values were encoded incorrectly or a required value was missing.
+ """
if not isinstance(name, bytes):
name = name.encode("ascii")
@@ -258,7 +309,7 @@ def assert_params_in_dict(body, required):
class RestServlet:
- """ A Synapse REST Servlet.
+ """A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
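As a quick illustration (not part of the patch), the new parse_list_from_args helper applied to a hand-rolled query-argument dict of the shape Twisted exposes as request.args (bytes keys mapping to lists of bytes values).

from synapse.http.servlet import parse_list_from_args

# Shape of Twisted's request.args: bytes keys -> lists of bytes values.
args = {b"user_id": [b"@alice:example.com", b"@bob:example.com"]}

# Values are decoded (as ascii by default)...
user_ids = parse_list_from_args(args, "user_id")
# -> ["@alice:example.com", "@bob:example.com"]

# ...or returned as raw bytes when encoding=None.
raw_ids = parse_list_from_args(args, b"user_id", encoding=None)
# -> [b"@alice:example.com", b"@bob:example.com"]

# A missing parameter raises KeyError, as noted in the docstring above.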
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 12ec3f851f..4a4fb5ef26 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -249,8 +249,7 @@ class SynapseRequest(Request):
)
def _finished_processing(self):
- """Log the completion of this request and update the metrics
- """
+ """Log the completion of this request and update the metrics"""
assert self.logcontext is not None
usage = self.logcontext.get_resource_usage()
@@ -276,7 +275,8 @@ class SynapseRequest(Request):
# authenticated (e.g. and admin is puppetting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format(
- authenticated_entity, self.requester.user.to_string(),
+ authenticated_entity,
+ self.requester.user.to_string(),
)
elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information
@@ -322,8 +322,7 @@ class SynapseRequest(Request):
logger.warning("Failed to stop metrics: %r", e)
def _should_log_request(self) -> bool:
- """Whether we should log at INFO that we processed the request.
- """
+ """Whether we should log at INFO that we processed the request."""
if self.path == b"/health":
return False
|