summary refs log tree commit diff
path: root/synapse/http/endpoint.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r--synapse/http/endpoint.py70
1 files changed, 66 insertions, 4 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 442696d393..8c64339a7c 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from twisted.internet.error import ConnectError
 from twisted.names import client, dns
 from twisted.names.error import DNSNameError, DomainError
@@ -66,13 +66,75 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         default_port = 8448
 
     if port is None:
-        return SRVClientEndpoint(
+        return _WrappingEndpointFac(SRVClientEndpoint(
             reactor, "matrix", domain, protocol="tcp",
             default_port=default_port, endpoint=transport_endpoint,
             endpoint_kw_args=endpoint_kw_args
-        )
+        ))
     else:
-        return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
+        return _WrappingEndpointFac(transport_endpoint(
+            reactor, domain, port, **endpoint_kw_args
+        ))
+
+
+class _WrappingEndpointFac(object):
+    def __init__(self, endpoint_fac):
+        self.endpoint_fac = endpoint_fac
+
+    @defer.inlineCallbacks
+    def connect(self, protocolFactory):
+        conn = yield self.endpoint_fac.connect(protocolFactory)
+        conn = _WrappedConnection(conn)
+        defer.returnValue(conn)
+
+
+class _WrappedConnection(object):
+    """Wraps a connection and calls abort on it if it hasn't seen any action
+    for 2.5-3 minutes.
+    """
+    __slots__ = ["conn", "last_request"]
+
+    def __init__(self, conn):
+        object.__setattr__(self, "conn", conn)
+        object.__setattr__(self, "last_request", time.time())
+
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+    def __setattr__(self, name, value):
+        setattr(self.conn, name, value)
+
+    def _time_things_out_maybe(self):
+        # We use a slightly shorter timeout here just in case the callLater is
+        # triggered early. Paranoia ftw.
+        # TODO: Cancel the previous callLater rather than comparing time.time()?
+        if time.time() - self.last_request >= 2.5 * 60:
+            self.abort()
+            # Abort the underlying TLS connection. The abort() method calls
+            # loseConnection() on the underlying TLS connection which tries to
+            # shutdown the connection cleanly. We call abortConnection()
+            # since that will promptly close the underlying TCP connection.
+            self.transport.abortConnection()
+
+    def request(self, request):
+        self.last_request = time.time()
+
+        # Time this connection out if we haven't send a request in the last
+        # N minutes
+        # TODO: Cancel the previous callLater?
+        reactor.callLater(3 * 60, self._time_things_out_maybe)
+
+        d = self.conn.request(request)
+
+        def update_request_time(res):
+            self.last_request = time.time()
+            # TODO: Cancel the previous callLater?
+            reactor.callLater(3 * 60, self._time_things_out_maybe)
+            return res
+
+        d.addCallback(update_request_time)
+
+        return d
 
 
 class SpiderEndpoint(object):