summary refs log tree commit diff
path: root/synapse/app/homeserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/homeserver.py')
-rwxr-xr-xsynapse/app/homeserver.py138
1 files changed, 67 insertions, 71 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 250a17cef8..b4476bf16e 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -17,7 +17,6 @@
 import gc
 import logging
 import os
-import signal
 import sys
 import traceback
 
@@ -28,7 +27,6 @@ from prometheus_client import Gauge
 
 from twisted.application import service
 from twisted.internet import defer, reactor
-from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web.resource import EncodingResourceWrapper, NoResource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
@@ -49,7 +47,6 @@ from synapse.app import _base
 from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
-from synapse.crypto import context_factory
 from synapse.federation.transport.server import TransportLayerServer
 from synapse.http.additional_resource import AdditionalResource
 from synapse.http.server import RootRedirect
@@ -86,7 +83,6 @@ def gz_wrap(r):
 
 class SynapseHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
-    _listening_services = []
 
     def _listener_http(self, config, listener_config):
         port = listener_config["port"]
@@ -102,6 +98,10 @@ class SynapseHomeServer(HomeServer):
         resources = {}
         for res in listener_config["resources"]:
             for name in res["names"]:
+                if name == "openid" and "federation" in res["names"]:
+                    # Skip loading openid resource if federation is defined
+                    # since federation resource will include openid
+                    continue
                 resources.update(self._configure_named_resource(
                     name, res.get("compress", False),
                 ))
@@ -137,6 +137,7 @@ class SynapseHomeServer(HomeServer):
                     self.version_string,
                 ),
                 self.tls_server_context_factory,
+                reactor=self.get_reactor(),
             )
 
         else:
@@ -149,7 +150,8 @@ class SynapseHomeServer(HomeServer):
                     listener_config,
                     root_resource,
                     self.version_string,
-                )
+                ),
+                reactor=self.get_reactor(),
             )
 
     def _configure_named_resource(self, name, compress=False):
@@ -196,6 +198,11 @@ class SynapseHomeServer(HomeServer):
                 FEDERATION_PREFIX: TransportLayerServer(self),
             })
 
+        if name == "openid":
+            resources.update({
+                FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]),
+            })
+
         if name in ["static", "client"]:
             resources.update({
                 STATIC_PREFIX: File(
@@ -241,10 +248,10 @@ class SynapseHomeServer(HomeServer):
 
         return resources
 
-    def start_listening(self):
+    def start_listening(self, listeners):
         config = self.get_config()
 
-        for listener in config.listeners:
+        for listener in listeners:
             if listener["type"] == "http":
                 self._listening_services.extend(
                     self._listener_http(config, listener)
@@ -328,20 +335,11 @@ def setup(config_options):
         # generating config files and shouldn't try to continue.
         sys.exit(0)
 
-    sighup_callbacks = []
     synapse.config.logger.setup_logging(
         config,
-        use_worker_options=False,
-        register_sighup=sighup_callbacks.append
+        use_worker_options=False
     )
 
-    def handle_sighup(*args, **kwargs):
-        for i in sighup_callbacks:
-            i(*args, **kwargs)
-
-    if hasattr(signal, "SIGHUP"):
-        signal.signal(signal.SIGHUP, handle_sighup)
-
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
     database_engine = create_engine(config.database_config)
@@ -377,76 +375,73 @@ def setup(config_options):
 
     hs.setup()
 
-    def refresh_certificate(*args):
+    @defer.inlineCallbacks
+    def do_acme():
         """
-        Refresh the TLS certificates that Synapse is using by re-reading them
-        from disk and updating the TLS context factories to use them.
+        Reprovision an ACME certificate, if it's required.
+
+        Returns:
+            Deferred[bool]: Whether the cert has been updated.
         """
-        logging.info("Reloading certificate from disk...")
-        hs.config.read_certificate_from_disk()
-        hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
-        hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
-            config
+        acme = hs.get_acme_handler()
+
+        # Check how long the certificate is active for.
+        cert_days_remaining = hs.config.is_disk_cert_valid(
+            allow_self_signed=False
         )
-        logging.info("Certificate reloaded.")
-
-        logging.info("Updating context factories...")
-        for i in hs._listening_services:
-            if isinstance(i.factory, TLSMemoryBIOFactory):
-                i.factory = TLSMemoryBIOFactory(
-                    hs.tls_server_context_factory,
-                    False,
-                    i.factory.wrappedFactory
-                )
-        logging.info("Context factories updated.")
 
-    sighup_callbacks.append(refresh_certificate)
+        # We want to reprovision if cert_days_remaining is None (meaning no
+        # certificate exists), or the days remaining number it returns
+        # is less than our re-registration threshold.
+        provision = False
+
+        if (cert_days_remaining is None):
+            provision = True
+
+        if cert_days_remaining > hs.config.acme_reprovision_threshold:
+            provision = True
+
+        if provision:
+            yield acme.provision_certificate()
+
+        defer.returnValue(provision)
+
+    @defer.inlineCallbacks
+    def reprovision_acme():
+        """
+        Provision a certificate from ACME, if required, and reload the TLS
+        certificate if it's renewed.
+        """
+        reprovisioned = yield do_acme()
+        if reprovisioned:
+            _base.refresh_certificate(hs)
 
     @defer.inlineCallbacks
     def start():
         try:
-            # Check if the certificate is still valid.
-            cert_days_remaining = hs.config.is_disk_cert_valid()
-
+            # Run the ACME provisioning code, if it's enabled.
             if hs.config.acme_enabled:
-                # If ACME is enabled, we might need to provision a certificate
-                # before starting.
                 acme = hs.get_acme_handler()
-
                 # Start up the webservices which we will respond to ACME
-                # challenges with.
+                # challenges with, and then provision.
                 yield acme.start_listening()
+                yield do_acme()
 
-                # We want to reprovision if cert_days_remaining is None (meaning no
-                # certificate exists), or the days remaining number it returns
-                # is less than our re-registration threshold.
-                if (cert_days_remaining is None) or (
-                    not cert_days_remaining > hs.config.acme_reprovision_threshold
-                ):
-                    yield acme.provision_certificate()
-
-            # Read the certificate from disk and build the context factories for
-            # TLS.
-            hs.config.read_certificate_from_disk()
-            hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
-            hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
-                config
-            )
+                # Check if it needs to be reprovisioned every day.
+                hs.get_clock().looping_call(
+                    reprovision_acme,
+                    24 * 60 * 60 * 1000
+                )
+
+            _base.start(hs, config.listeners)
 
-            # It is now safe to start your Synapse.
-            hs.start_listening()
             hs.get_pusherpool().start()
-            hs.get_datastore().start_profiling()
             hs.get_datastore().start_doing_background_updates()
-        except Exception as e:
-            # If a DeferredList failed (like in listening on the ACME listener),
-            # we need to print the subfailure explicitly.
-            if isinstance(e, defer.FirstError):
-                e.subFailure.printTraceback(sys.stderr)
-                sys.exit(1)
-
-            # Something else went wrong when starting. Print it and bail out.
+        except Exception:
+            # Print the exception and bail out.
             traceback.print_exc(file=sys.stderr)
+            if reactor.running:
+                reactor.stop()
             sys.exit(1)
 
     reactor.callWhenRunning(start)
@@ -455,7 +450,8 @@ def setup(config_options):
 
 
 class SynapseService(service.Service):
-    """A twisted Service class that will start synapse. Used to run synapse
+    """
+    A twisted Service class that will start synapse. Used to run synapse
     via twistd and a .tac.
     """
     def __init__(self, config):