diff options
Diffstat (limited to 'synapse/crypto/keyclient.py')
-rw-r--r-- | synapse/crypto/keyclient.py | 75 |
1 files changed, 29 insertions, 46 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index c11df5c529..5949ea0573 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -15,9 +15,10 @@ from twisted.web.http import HTTPClient +from twisted.internet.protocol import Factory from twisted.internet import defer, reactor -from twisted.internet.protocol import ClientFactory -from twisted.names.srvconnect import SRVConnector +from twisted.internet.endpoints import connectProtocol +from synapse.http.endpoint import matrix_endpoint import json import logging @@ -30,15 +31,19 @@ def fetch_server_key(server_name, ssl_context_factory): """Fetch the keys for a remote server.""" factory = SynapseKeyClientFactory() + endpoint = matrix_endpoint( + reactor, server_name, ssl_context_factory, timeout=30 + ) - SRVConnector( - reactor, "matrix", server_name, factory, - protocol="tcp", connectFuncName="connectSSL", defaultPort=443, - connectFuncKwArgs=dict(contextFactory=ssl_context_factory)).connect() - - server_key, server_certificate = yield factory.remote_key - - defer.returnValue((server_key, server_certificate)) + for i in range(5): + try: + protocol = yield endpoint.connect(factory) + server_response, server_certificate = yield protocol.remote_key + defer.returnValue((server_response, server_certificate)) + return + except Exception as e: + logger.exception(e) + raise IOError("Cannot get key for %s" % server_name) class SynapseKeyClientError(Exception): @@ -51,69 +56,47 @@ class SynapseKeyClientProtocol(HTTPClient): the server and extracts the X.509 certificate for the remote peer from the SSL connection.""" + timeout = 30 + + def __init__(self): + self.remote_key = defer.Deferred() + def connectionMade(self): logger.debug("Connected to %s", self.transport.getHost()) - self.sendCommand(b"GET", b"/key") + self.sendCommand(b"GET", b"/_matrix/key/v1/") self.endHeaders() self.timer = reactor.callLater( - self.factory.timeout_seconds, + self.timeout, self.on_timeout ) def handleStatus(self, version, status, message): if status != b"200": - logger.info("Non-200 response from %s: %s %s", - self.transport.getHost(), status, message) + #logger.info("Non-200 response from %s: %s %s", + # self.transport.getHost(), status, message) self.transport.abortConnection() def handleResponse(self, response_body_bytes): try: json_response = json.loads(response_body_bytes) except ValueError: - logger.info("Invalid JSON response from %s", - self.transport.getHost()) + #logger.info("Invalid JSON response from %s", + # self.transport.getHost()) self.transport.abortConnection() return certificate = self.transport.getPeerCertificate() - self.factory.on_remote_key((json_response, certificate)) + self.remote_key.callback((json_response, certificate)) self.transport.abortConnection() self.timer.cancel() def on_timeout(self): logger.debug("Timeout waiting for response from %s", self.transport.getHost()) + self.remote_key.errback(IOError("Timeout waiting for response")) self.transport.abortConnection() -class SynapseKeyClientFactory(ClientFactory): +class SynapseKeyClientFactory(Factory): protocol = SynapseKeyClientProtocol - max_retries = 5 - timeout_seconds = 30 - - def __init__(self): - self.succeeded = False - self.retries = 0 - self.remote_key = defer.Deferred() - def on_remote_key(self, key): - self.succeeded = True - self.remote_key.callback(key) - - def retry_connection(self, connector): - self.retries += 1 - if self.retries < self.max_retries: - connector.connector = None - connector.connect() - else: - self.remote_key.errback( - SynapseKeyClientError("Max retries exceeded")) - - def clientConnectionFailed(self, connector, reason): - logger.info("Connection failed %s", reason) - self.retry_connection(connector) - - def clientConnectionLost(self, connector, reason): - logger.info("Connection lost %s", reason) - if not self.succeeded: - self.retry_connection(connector) |