summary refs log tree commit diff
path: root/synapse/app/homeserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/homeserver.py')
-rw-r--r--synapse/app/homeserver.py92
1 files changed, 38 insertions, 54 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 336c279a44..7bb3744f04 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -16,10 +16,10 @@
 import logging
 import os
 import sys
-from typing import Iterator
+from typing import Dict, Iterable, Iterator, List
 
-from twisted.internet import reactor
-from twisted.web.resource import EncodingResourceWrapper, IResource
+from twisted.internet.tcp import Port
+from twisted.web.resource import EncodingResourceWrapper, Resource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
 
@@ -76,23 +76,27 @@ from synapse.util.versionstring import get_version_string
 logger = logging.getLogger("synapse.app.homeserver")
 
 
-def gz_wrap(r):
+def gz_wrap(r: Resource) -> Resource:
     return EncodingResourceWrapper(r, [GzipEncoderFactory()])
 
 
 class SynapseHomeServer(HomeServer):
-    DATASTORE_CLASS = DataStore
+    DATASTORE_CLASS = DataStore  # type: ignore
 
-    def _listener_http(self, config: HomeServerConfig, listener_config: ListenerConfig):
+    def _listener_http(
+        self, config: HomeServerConfig, listener_config: ListenerConfig
+    ) -> Iterable[Port]:
         port = listener_config.port
         bind_addresses = listener_config.bind_addresses
         tls = listener_config.tls
+        # Must exist since this is an HTTP listener.
+        assert listener_config.http_options is not None
         site_tag = listener_config.http_options.tag
         if site_tag is None:
             site_tag = str(port)
 
         # We always include a health resource.
-        resources = {"/health": HealthResource()}
+        resources: Dict[str, Resource] = {"/health": HealthResource()}
 
         for res in listener_config.http_options.resources:
             for name in res.names:
@@ -111,7 +115,7 @@ class SynapseHomeServer(HomeServer):
                 ("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
             )
             handler = handler_cls(config, module_api)
-            if IResource.providedBy(handler):
+            if isinstance(handler, Resource):
                 resource = handler
             elif hasattr(handler, "handle_request"):
                 resource = AdditionalResource(self, handler.handle_request)
@@ -128,7 +132,7 @@ class SynapseHomeServer(HomeServer):
 
         # try to find something useful to redirect '/' to
         if WEB_CLIENT_PREFIX in resources:
-            root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
+            root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX)
         elif STATIC_PREFIX in resources:
             root_resource = RootOptionsRedirectResource(STATIC_PREFIX)
         else:
@@ -145,6 +149,8 @@ class SynapseHomeServer(HomeServer):
         )
 
         if tls:
+            # refresh_certificate should have been called before this.
+            assert self.tls_server_context_factory is not None
             ports = listen_ssl(
                 bind_addresses,
                 port,
@@ -165,20 +171,21 @@ class SynapseHomeServer(HomeServer):
 
         return ports
 
-    def _configure_named_resource(self, name, compress=False):
+    def _configure_named_resource(
+        self, name: str, compress: bool = False
+    ) -> Dict[str, Resource]:
         """Build a resource map for a named resource
 
         Args:
-            name (str): named resource: one of "client", "federation", etc
-            compress (bool): whether to enable gzip compression for this
-                resource
+            name: named resource: one of "client", "federation", etc
+            compress: whether to enable gzip compression for this resource
 
         Returns:
-            dict[str, Resource]: map from path to HTTP resource
+            map from path to HTTP resource
         """
-        resources = {}
+        resources: Dict[str, Resource] = {}
         if name == "client":
-            client_resource = ClientRestResource(self)
+            client_resource: Resource = ClientRestResource(self)
             if compress:
                 client_resource = gz_wrap(client_resource)
 
@@ -207,7 +214,7 @@ class SynapseHomeServer(HomeServer):
         if name == "consent":
             from synapse.rest.consent.consent_resource import ConsentResource
 
-            consent_resource = ConsentResource(self)
+            consent_resource: Resource = ConsentResource(self)
             if compress:
                 consent_resource = gz_wrap(consent_resource)
             resources.update({"/_matrix/consent": consent_resource})
@@ -277,7 +284,7 @@ class SynapseHomeServer(HomeServer):
 
         return resources
 
-    def start_listening(self):
+    def start_listening(self) -> None:
         if self.config.redis.redis_enabled:
             # If redis is enabled we connect via the replication command handler
             # in the same way as the workers (since we're effectively a client
@@ -303,7 +310,9 @@ class SynapseHomeServer(HomeServer):
                     ReplicationStreamProtocolFactory(self),
                 )
                 for s in services:
-                    reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
+                    self.get_reactor().addSystemEventTrigger(
+                        "before", "shutdown", s.stopListening
+                    )
             elif listener.type == "metrics":
                 if not self.config.metrics.enable_metrics:
                     logger.warning(
@@ -318,14 +327,13 @@ class SynapseHomeServer(HomeServer):
                 logger.warning("Unrecognized listener type: %s", listener.type)
 
 
-def setup(config_options):
+def setup(config_options: List[str]) -> SynapseHomeServer:
     """
     Args:
-        config_options_options: The options passed to Synapse. Usually
-            `sys.argv[1:]`.
+        config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`.
 
     Returns:
-        HomeServer
+        A homeserver instance.
     """
     try:
         config = HomeServerConfig.load_or_generate_config(
@@ -364,7 +372,7 @@ def setup(config_options):
     except Exception as e:
         handle_startup_exception(e)
 
-    async def start():
+    async def start() -> None:
         # Load the OIDC provider metadatas, if OIDC is enabled.
         if hs.config.oidc.oidc_enabled:
             oidc = hs.get_oidc_handler()
@@ -404,39 +412,15 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
 
     yield ":\n  %s" % (e.msg,)
 
-    e = e.__cause__
+    parent_e = e.__cause__
     indent = 1
-    while e:
+    while parent_e:
         indent += 1
-        yield ":\n%s%s" % ("  " * indent, str(e))
-        e = e.__cause__
-
-
-def run(hs: HomeServer):
-    PROFILE_SYNAPSE = False
-    if PROFILE_SYNAPSE:
-
-        def profile(func):
-            from cProfile import Profile
-            from threading import current_thread
-
-            def profiled(*args, **kargs):
-                profile = Profile()
-                profile.enable()
-                func(*args, **kargs)
-                profile.disable()
-                ident = current_thread().ident
-                profile.dump_stats(
-                    "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
-                )
-
-            return profiled
-
-        from twisted.python.threadpool import ThreadPool
+        yield ":\n%s%s" % ("  " * indent, str(parent_e))
+        parent_e = parent_e.__cause__
 
-        ThreadPool._worker = profile(ThreadPool._worker)
-        reactor.run = profile(reactor.run)
 
+def run(hs: HomeServer) -> None:
     _base.start_reactor(
         "synapse-homeserver",
         soft_file_limit=hs.config.server.soft_file_limit,
@@ -448,7 +432,7 @@ def run(hs: HomeServer):
     )
 
 
-def main():
+def main() -> None:
     with LoggingContext("main"):
         # check base requirements
         check_requirements()