diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c11df5c529..c26f16a038 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -15,9 +15,10 @@
from twisted.web.http import HTTPClient
+from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
-from twisted.internet.protocol import ClientFactory
-from twisted.names.srvconnect import SRVConnector
+from twisted.internet.endpoints import connectProtocol
+from synapse.http.endpoint import matrix_endpoint
import json
import logging
@@ -30,15 +31,19 @@ def fetch_server_key(server_name, ssl_context_factory):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
+ endpoint = matrix_endpoint(
+ reactor, server_name, ssl_context_factory, timeout=30
+ )
- SRVConnector(
- reactor, "matrix", server_name, factory,
- protocol="tcp", connectFuncName="connectSSL", defaultPort=443,
- connectFuncKwArgs=dict(contextFactory=ssl_context_factory)).connect()
-
- server_key, server_certificate = yield factory.remote_key
-
- defer.returnValue((server_key, server_certificate))
+ for i in range(5):
+ try:
+ protocol = yield endpoint.connect(factory)
+ server_response, server_certificate = yield protocol.remote_key
+ defer.returnValue((server_response, server_certificate))
+ return
+ except Exception as e:
+ logger.exception(e)
+ raise IOError("Cannot get key for " % server_name)
class SynapseKeyClientError(Exception):
@@ -51,69 +56,47 @@ class SynapseKeyClientProtocol(HTTPClient):
the server and extracts the X.509 certificate for the remote peer from the
SSL connection."""
+ timeout = 30
+
+ def __init__(self):
+ self.remote_key = defer.Deferred()
+
def connectionMade(self):
logger.debug("Connected to %s", self.transport.getHost())
- self.sendCommand(b"GET", b"/key")
+ self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders()
self.timer = reactor.callLater(
- self.factory.timeout_seconds,
+ self.timeout,
self.on_timeout
)
def handleStatus(self, version, status, message):
if status != b"200":
- logger.info("Non-200 response from %s: %s %s",
- self.transport.getHost(), status, message)
+ #logger.info("Non-200 response from %s: %s %s",
+ # self.transport.getHost(), status, message)
self.transport.abortConnection()
def handleResponse(self, response_body_bytes):
try:
json_response = json.loads(response_body_bytes)
except ValueError:
- logger.info("Invalid JSON response from %s",
- self.transport.getHost())
+ #logger.info("Invalid JSON response from %s",
+ # self.transport.getHost())
self.transport.abortConnection()
return
certificate = self.transport.getPeerCertificate()
- self.factory.on_remote_key((json_response, certificate))
+ self.remote_key.callback((json_response, certificate))
self.transport.abortConnection()
self.timer.cancel()
def on_timeout(self):
logger.debug("Timeout waiting for response from %s",
self.transport.getHost())
+ self.on_remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
-class SynapseKeyClientFactory(ClientFactory):
+class SynapseKeyClientFactory(Factory):
protocol = SynapseKeyClientProtocol
- max_retries = 5
- timeout_seconds = 30
-
- def __init__(self):
- self.succeeded = False
- self.retries = 0
- self.remote_key = defer.Deferred()
- def on_remote_key(self, key):
- self.succeeded = True
- self.remote_key.callback(key)
-
- def retry_connection(self, connector):
- self.retries += 1
- if self.retries < self.max_retries:
- connector.connector = None
- connector.connect()
- else:
- self.remote_key.errback(
- SynapseKeyClientError("Max retries exceeded"))
-
- def clientConnectionFailed(self, connector, reason):
- logger.info("Connection failed %s", reason)
- self.retry_connection(connector)
-
- def clientConnectionLost(self, connector, reason):
- logger.info("Connection lost %s", reason)
- if not self.succeeded:
- self.retry_connection(connector)
|