diff options
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r-- | synapse/http/endpoint.py | 70 |
1 files changed, 66 insertions, 4 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 442696d393..8c64339a7c 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -14,7 +14,7 @@ # limitations under the License. from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint -from twisted.internet import defer +from twisted.internet import defer, reactor from twisted.internet.error import ConnectError from twisted.names import client, dns from twisted.names.error import DNSNameError, DomainError @@ -66,13 +66,75 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, default_port = 8448 if port is None: - return SRVClientEndpoint( + return _WrappingEndpointFac(SRVClientEndpoint( reactor, "matrix", domain, protocol="tcp", default_port=default_port, endpoint=transport_endpoint, endpoint_kw_args=endpoint_kw_args - ) + )) else: - return transport_endpoint(reactor, domain, port, **endpoint_kw_args) + return _WrappingEndpointFac(transport_endpoint( + reactor, domain, port, **endpoint_kw_args + )) + + +class _WrappingEndpointFac(object): + def __init__(self, endpoint_fac): + self.endpoint_fac = endpoint_fac + + @defer.inlineCallbacks + def connect(self, protocolFactory): + conn = yield self.endpoint_fac.connect(protocolFactory) + conn = _WrappedConnection(conn) + defer.returnValue(conn) + + +class _WrappedConnection(object): + """Wraps a connection and calls abort on it if it hasn't seen any action + for 2.5-3 minutes. + """ + __slots__ = ["conn", "last_request"] + + def __init__(self, conn): + object.__setattr__(self, "conn", conn) + object.__setattr__(self, "last_request", time.time()) + + def __getattr__(self, name): + return getattr(self.conn, name) + + def __setattr__(self, name, value): + setattr(self.conn, name, value) + + def _time_things_out_maybe(self): + # We use a slightly shorter timeout here just in case the callLater is + # triggered early. Paranoia ftw. + # TODO: Cancel the previous callLater rather than comparing time.time()? + if time.time() - self.last_request >= 2.5 * 60: + self.abort() + # Abort the underlying TLS connection. The abort() method calls + # loseConnection() on the underlying TLS connection which tries to + # shutdown the connection cleanly. We call abortConnection() + # since that will promptly close the underlying TCP connection. + self.transport.abortConnection() + + def request(self, request): + self.last_request = time.time() + + # Time this connection out if we haven't send a request in the last + # N minutes + # TODO: Cancel the previous callLater? + reactor.callLater(3 * 60, self._time_things_out_maybe) + + d = self.conn.request(request) + + def update_request_time(res): + self.last_request = time.time() + # TODO: Cancel the previous callLater? + reactor.callLater(3 * 60, self._time_things_out_maybe) + return res + + d.addCallback(update_request_time) + + return d class SpiderEndpoint(object): |