summary refs log tree commit diff
path: root/synapse/http/matrixfederationclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/matrixfederationclient.py')
-rw-r--r--synapse/http/matrixfederationclient.py74
1 files changed, 10 insertions, 64 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b2ccae90df..4e27f93b7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,7 @@ import random
 import sys
 import urllib.parse
 from io import BytesIO
-from typing import BinaryIO, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import attr
 import treq
@@ -28,26 +28,27 @@ from prometheus_client import Counter
 from signedjson.sign import sign_json
 from zope.interface import implementer
 
-from twisted.internet import defer, protocol
+from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
 from twisted.internet.task import _EPSILON, Cooperator
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
 
 import synapse.metrics
 import synapse.util.retryutils
 from synapse.api.errors import (
-    Codes,
     FederationDeniedError,
     HttpResponseException,
     RequestSendFailed,
-    SynapseError,
 )
 from synapse.http import QuieterFileBodyProducer
-from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
+from synapse.http.client import (
+    BlacklistingAgentWrapper,
+    IPBlacklistingResolver,
+    encode_query_args,
+    readBodyToFile,
+)
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import (
@@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
         # blacklist via IP literals in server names
         self.agent = BlacklistingAgentWrapper(
-            self.agent,
-            self.reactor,
-            ip_blacklist=hs.config.federation_ip_range_blacklist,
+            self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
         )
 
         self.clock = hs.get_clock()
@@ -986,7 +985,7 @@ class MatrixFederationHttpClient:
         headers = dict(response.headers.getAllRawHeaders())
 
         try:
-            d = _readBodyToFile(response, output_stream, max_size)
+            d = readBodyToFile(response, output_stream, max_size)
             d.addTimeout(self.default_timeout, self.reactor)
             length = await make_deferred_yieldable(d)
         except Exception as e:
@@ -1010,44 +1009,6 @@ class MatrixFederationHttpClient:
         return (length, headers)
 
 
-class _ReadBodyToFileProtocol(protocol.Protocol):
-    def __init__(
-        self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
-    ):
-        self.stream = stream
-        self.deferred = deferred
-        self.length = 0
-        self.max_size = max_size
-
-    def dataReceived(self, data: bytes) -> None:
-        self.stream.write(data)
-        self.length += len(data)
-        if self.max_size is not None and self.length >= self.max_size:
-            self.deferred.errback(
-                SynapseError(
-                    502,
-                    "Requested file is too large > %r bytes" % (self.max_size,),
-                    Codes.TOO_LARGE,
-                )
-            )
-            self.deferred = defer.Deferred()
-            self.transport.loseConnection()
-
-    def connectionLost(self, reason: Failure) -> None:
-        if reason.check(ResponseDone):
-            self.deferred.callback(self.length)
-        else:
-            self.deferred.errback(reason)
-
-
-def _readBodyToFile(
-    response: IResponse, stream: BinaryIO, max_size: Optional[int]
-) -> defer.Deferred:
-    d = defer.Deferred()
-    response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
-    return d
-
-
 def _flatten_response_never_received(e):
     if hasattr(e, "reasons"):
         reasons = ", ".join(
@@ -1088,18 +1049,3 @@ def check_content_type_is_json(headers: Headers) -> None:
             ),
             can_retry=False,
         )
-
-
-def encode_query_args(args: Optional[QueryArgs]) -> bytes:
-    if args is None:
-        return b""
-
-    encoded_args = {}
-    for k, vs in args.items():
-        if isinstance(vs, str):
-            vs = [vs]
-        encoded_args[k] = [v.encode("utf8") for v in vs]
-
-    query_str = urllib.parse.urlencode(encoded_args, True)
-
-    return query_str.encode("utf8")