summary refs log tree commit diff
path: root/synapse/http/endpoint.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/endpoint.py58
1 files changed, 54 insertions, 4 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 442696d393..16ad09a481 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
-from twisted.internet import defer
+from twisted.internet import defer, reactor, task
 from twisted.internet.error import ConnectError
 from twisted.names import client, dns
 from twisted.names.error import DNSNameError, DomainError
@@ -66,13 +66,63 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         default_port = 8448
 
     if port is None:
-        return SRVClientEndpoint(
+        return _WrappingEndointFac(SRVClientEndpoint(
             reactor, "matrix", domain, protocol="tcp",
             default_port=default_port, endpoint=transport_endpoint,
             endpoint_kw_args=endpoint_kw_args
-        )
+        ))
     else:
-        return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
+        return _WrappingEndointFac(transport_endpoint(reactor, domain, port, **endpoint_kw_args))
+
+
+class _WrappingEndointFac(object):
+    def __init__(self, endpoint_fac):
+        self.endpoint_fac = endpoint_fac
+
+    @defer.inlineCallbacks
+    def connect(self, protocolFactory):
+        conn = yield self.endpoint_fac.connect(protocolFactory)
+        conn = _WrappedConnection(conn)
+        defer.returnValue(conn)
+
+
+class _WrappedConnection(object):
+    """Wraps a connection and calls abort on it if it hasn't seen any actio
+    for 5 minutes
+    """
+    __slots__ = ["conn", "last_request"]
+
+    def __init__(self, conn):
+        object.__setattr__(self, "conn", conn)
+        object.__setattr__(self, "last_request", time.time())
+
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+    def __setattr__(self, name, value):
+        setattr(self.conn, name, value)
+
+    def _time_things_out_maybe(self):
+        if time.time() - self.last_request >= 2 * 60:
+            self.abort()
+
+    def request(self, request):
+        self.last_request = time.time()
+
+        # Time this connection out if we haven't send a request in the last
+        # N minutes
+        reactor.callLater(3 * 60, self._time_things_out_maybe)
+
+        d = self.conn.request(request)
+
+        def update_request_time(res):
+            self.last_request = time.time()
+            reactor.callLater(3 * 60, self._time_things_out_maybe)
+            return res
+
+        d.addCallback(update_request_time)
+
+        return d
 
 
 class SpiderEndpoint(object):