summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4522.feature1
-rw-r--r--synapse/app/_base.py19
-rwxr-xr-xsynapse/app/homeserver.py79
-rw-r--r--synapse/config/tls.py12
-rw-r--r--synapse/server.py3
5 files changed, 89 insertions, 25 deletions
diff --git a/changelog.d/4522.feature b/changelog.d/4522.feature
new file mode 100644
index 0000000000..ef18daf60b
--- /dev/null
+++ b/changelog.d/4522.feature
@@ -0,0 +1 @@
+Synapse's ACME support will now correctly reprovision a certificate that approaches its expiry while Synapse is running.
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3cbb003035..62c633146f 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -23,6 +23,7 @@ import psutil
 from daemonize import Daemonize
 
 from twisted.internet import error, reactor
+from twisted.protocols.tls import TLSMemoryBIOFactory
 
 from synapse.app import check_bind_error
 from synapse.crypto import context_factory
@@ -220,6 +221,24 @@ def refresh_certificate(hs):
     )
     logging.info("Certificate loaded.")
 
+    if hs._listening_services:
+        logging.info("Updating context factories...")
+        for i in hs._listening_services:
+            # When you listenSSL, it doesn't make an SSL port but a TCP one with
+            # a TLS wrapping factory around the factory you actually want to get
+            # requests. This factory attribute is public but missing from
+            # Twisted's documentation.
+            if isinstance(i.factory, TLSMemoryBIOFactory):
+                # We want to replace TLS factories with a new one, with the new
+                # TLS configuration. We do this by reaching in and pulling out
+                # the wrappedFactory, and then re-wrapping it.
+                i.factory = TLSMemoryBIOFactory(
+                    hs.tls_server_context_factory,
+                    False,
+                    i.factory.wrappedFactory
+                )
+        logging.info("Context factories updated.")
+
 
 def start(hs, listeners=None):
     """
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index d1cab07bb6..b4476bf16e 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -83,7 +83,6 @@ def gz_wrap(r):
 
 class SynapseHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
-    _listening_services = []
 
     def _listener_http(self, config, listener_config):
         port = listener_config["port"]
@@ -377,41 +376,72 @@ def setup(config_options):
     hs.setup()
 
     @defer.inlineCallbacks
+    def do_acme():
+        """
+        Reprovision an ACME certificate, if it's required.
+
+        Returns:
+            Deferred[bool]: Whether the cert has been updated.
+        """
+        acme = hs.get_acme_handler()
+
+        # Check how long the certificate is active for.
+        cert_days_remaining = hs.config.is_disk_cert_valid(
+            allow_self_signed=False
+        )
+
+        # We want to reprovision if cert_days_remaining is None (meaning no
+        # certificate exists), or the days remaining number it returns
+        # is less than our re-registration threshold.
+        provision = False
+
+        if (cert_days_remaining is None):
+            provision = True
+
+        if cert_days_remaining > hs.config.acme_reprovision_threshold:
+            provision = True
+
+        if provision:
+            yield acme.provision_certificate()
+
+        defer.returnValue(provision)
+
+    @defer.inlineCallbacks
+    def reprovision_acme():
+        """
+        Provision a certificate from ACME, if required, and reload the TLS
+        certificate if it's renewed.
+        """
+        reprovisioned = yield do_acme()
+        if reprovisioned:
+            _base.refresh_certificate(hs)
+
+    @defer.inlineCallbacks
     def start():
         try:
-            # Check if the certificate is still valid.
-            cert_days_remaining = hs.config.is_disk_cert_valid()
-
+            # Run the ACME provisioning code, if it's enabled.
             if hs.config.acme_enabled:
-                # If ACME is enabled, we might need to provision a certificate
-                # before starting.
                 acme = hs.get_acme_handler()
-
                 # Start up the webservices which we will respond to ACME
-                # challenges with.
+                # challenges with, and then provision.
                 yield acme.start_listening()
+                yield do_acme()
 
-                # We want to reprovision if cert_days_remaining is None (meaning no
-                # certificate exists), or the days remaining number it returns
-                # is less than our re-registration threshold.
-                if (cert_days_remaining is None) or (
-                    not cert_days_remaining > hs.config.acme_reprovision_threshold
-                ):
-                    yield acme.provision_certificate()
+                # Check if it needs to be reprovisioned every day.
+                hs.get_clock().looping_call(
+                    reprovision_acme,
+                    24 * 60 * 60 * 1000
+                )
 
             _base.start(hs, config.listeners)
 
             hs.get_pusherpool().start()
             hs.get_datastore().start_doing_background_updates()
-        except Exception as e:
-            # If a DeferredList failed (like in listening on the ACME listener),
-            # we need to print the subfailure explicitly.
-            if isinstance(e, defer.FirstError):
-                e.subFailure.printTraceback(sys.stderr)
-                sys.exit(1)
-
-            # Something else went wrong when starting. Print it and bail out.
+        except Exception:
+            # Print the exception and bail out.
             traceback.print_exc(file=sys.stderr)
+            if reactor.running:
+                reactor.stop()
             sys.exit(1)
 
     reactor.callWhenRunning(start)
@@ -420,7 +450,8 @@ def setup(config_options):
 
 
 class SynapseService(service.Service):
-    """A twisted Service class that will start synapse. Used to run synapse
+    """
+    A twisted Service class that will start synapse. Used to run synapse
     via twistd and a .tac.
     """
     def __init__(self, config):
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 81b3a659fe..9fcc79816d 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -64,10 +64,14 @@ class TlsConfig(Config):
         self.tls_certificate = None
         self.tls_private_key = None
 
-    def is_disk_cert_valid(self):
+    def is_disk_cert_valid(self, allow_self_signed=True):
         """
         Is the certificate we have on disk valid, and if so, for how long?
 
+        Args:
+            allow_self_signed (bool): Should we allow the certificate we
+                read to be self signed?
+
         Returns:
             int: Days remaining of certificate validity.
             None: No certificate exists.
@@ -88,6 +92,12 @@ class TlsConfig(Config):
             logger.exception("Failed to parse existing certificate off disk!")
             raise
 
+        if not allow_self_signed:
+            if tls_certificate.get_subject() == tls_certificate.get_issuer():
+                raise ValueError(
+                    "TLS Certificate is self signed, and this is not permitted"
+                )
+
         # YYYYMMDDhhmmssZ -- in UTC
         expires_on = datetime.strptime(
             tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
diff --git a/synapse/server.py b/synapse/server.py
index 6c52101616..a2cf8a91cd 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -112,6 +112,8 @@ class HomeServer(object):
 
     Attributes:
         config (synapse.config.homeserver.HomeserverConfig):
+        _listening_services (list[twisted.internet.tcp.Port]): TCP ports that
+            we are listening on to provide HTTP services.
     """
 
     __metaclass__ = abc.ABCMeta
@@ -196,6 +198,7 @@ class HomeServer(object):
         self._reactor = reactor
         self.hostname = hostname
         self._building = {}
+        self._listening_services = []
 
         self.clock = Clock(reactor)
         self.distributor = Distributor()