summary refs log tree commit diff
path: root/synapse/app/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/_base.py')
-rw-r--r--synapse/app/_base.py140
1 files changed, 89 insertions, 51 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index f2c1028b5d..573bb487b2 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -22,13 +22,27 @@ import socket
 import sys
 import traceback
 import warnings
-from typing import TYPE_CHECKING, Awaitable, Callable, Iterable
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    NoReturn,
+    Tuple,
+    cast,
+)
 
 from cryptography.utils import CryptographyDeprecationWarning
-from typing_extensions import NoReturn
 
 import twisted
-from twisted.internet import defer, error, reactor
+from twisted.internet import defer, error, reactor as _reactor
+from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorSSL, IReactorTCP
+from twisted.internet.protocol import ServerFactory
+from twisted.internet.tcp import Port
 from twisted.logger import LoggingFile, LogLevel
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.python.threadpool import ThreadPool
@@ -48,6 +62,7 @@ from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics import register_threadpool
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
+from synapse.types import ISynapseReactor
 from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
 from synapse.util.daemonize import daemonize_process
 from synapse.util.gai_resolver import GAIResolver
@@ -57,33 +72,44 @@ from synapse.util.versionstring import get_version_string
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+# Twisted injects the global reactor to make it easier to import, this confuses
+# mypy which thinks it is a module. Tell it that it a more proper type.
+reactor = cast(ISynapseReactor, _reactor)
+
+
 logger = logging.getLogger(__name__)
 
 # list of tuples of function, args list, kwargs dict
-_sighup_callbacks = []
+_sighup_callbacks: List[
+    Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
+] = []
 
 
-def register_sighup(func, *args, **kwargs):
+def register_sighup(func: Callable[..., None], *args: Any, **kwargs: Any) -> None:
     """
     Register a function to be called when a SIGHUP occurs.
 
     Args:
-        func (function): Function to be called when sent a SIGHUP signal.
+        func: Function to be called when sent a SIGHUP signal.
         *args, **kwargs: args and kwargs to be passed to the target function.
     """
     _sighup_callbacks.append((func, args, kwargs))
 
 
-def start_worker_reactor(appname, config, run_command=reactor.run):
+def start_worker_reactor(
+    appname: str,
+    config: HomeServerConfig,
+    run_command: Callable[[], None] = reactor.run,
+) -> None:
     """Run the reactor in the main process
 
     Daemonizes if necessary, and then configures some resources, before starting
     the reactor. Pulls configuration from the 'worker' settings in 'config'.
 
     Args:
-        appname (str): application name which will be sent to syslog
-        config (synapse.config.Config): config object
-        run_command (Callable[]): callable that actually runs the reactor
+        appname: application name which will be sent to syslog
+        config: config object
+        run_command: callable that actually runs the reactor
     """
 
     logger = logging.getLogger(config.worker.worker_app)
@@ -101,32 +127,32 @@ def start_worker_reactor(appname, config, run_command=reactor.run):
 
 
 def start_reactor(
-    appname,
-    soft_file_limit,
-    gc_thresholds,
-    pid_file,
-    daemonize,
-    print_pidfile,
-    logger,
-    run_command=reactor.run,
-):
+    appname: str,
+    soft_file_limit: int,
+    gc_thresholds: Tuple[int, int, int],
+    pid_file: str,
+    daemonize: bool,
+    print_pidfile: bool,
+    logger: logging.Logger,
+    run_command: Callable[[], None] = reactor.run,
+) -> None:
     """Run the reactor in the main process
 
     Daemonizes if necessary, and then configures some resources, before starting
     the reactor
 
     Args:
-        appname (str): application name which will be sent to syslog
-        soft_file_limit (int):
+        appname: application name which will be sent to syslog
+        soft_file_limit:
         gc_thresholds:
-        pid_file (str): name of pid file to write to if daemonize is True
-        daemonize (bool): true to run the reactor in a background process
-        print_pidfile (bool): whether to print the pid file, if daemonize is True
-        logger (logging.Logger): logger instance to pass to Daemonize
-        run_command (Callable[]): callable that actually runs the reactor
+        pid_file: name of pid file to write to if daemonize is True
+        daemonize: true to run the reactor in a background process
+        print_pidfile: whether to print the pid file, if daemonize is True
+        logger: logger instance to pass to Daemonize
+        run_command: callable that actually runs the reactor
     """
 
-    def run():
+    def run() -> None:
         logger.info("Running")
         setup_jemalloc_stats()
         change_resource_limit(soft_file_limit)
@@ -185,7 +211,7 @@ def redirect_stdio_to_logs() -> None:
     print("Redirected stdout/stderr to logs")
 
 
-def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
+def register_start(cb: Callable[..., Awaitable], *args: Any, **kwargs: Any) -> None:
     """Register a callback with the reactor, to be called once it is running
 
     This can be used to initialise parts of the system which require an asynchronous
@@ -195,7 +221,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
     will exit.
     """
 
-    async def wrapper():
+    async def wrapper() -> None:
         try:
             await cb(*args, **kwargs)
         except Exception:
@@ -224,7 +250,7 @@ def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
     reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
 
 
-def listen_metrics(bind_addresses, port):
+def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
     """
     Start Prometheus metrics server.
     """
@@ -236,11 +262,11 @@ def listen_metrics(bind_addresses, port):
 
 
 def listen_manhole(
-    bind_addresses: Iterable[str],
+    bind_addresses: Collection[str],
     port: int,
     manhole_settings: ManholeConfig,
     manhole_globals: dict,
-):
+) -> None:
     # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
     # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
     # suppress the warning for now.
@@ -259,12 +285,18 @@ def listen_manhole(
     )
 
 
-def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
+def listen_tcp(
+    bind_addresses: Collection[str],
+    port: int,
+    factory: ServerFactory,
+    reactor: IReactorTCP = reactor,
+    backlog: int = 50,
+) -> List[Port]:
     """
     Create a TCP socket for a port and several addresses
 
     Returns:
-        list[twisted.internet.tcp.Port]: listening for TCP connections
+        list of twisted.internet.tcp.Port listening for TCP connections
     """
     r = []
     for address in bind_addresses:
@@ -273,12 +305,19 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)
 
-    return r
+    # IReactorTCP returns an object implementing IListeningPort from listenTCP,
+    # but we know it will be a Port instance.
+    return r  # type: ignore[return-value]
 
 
 def listen_ssl(
-    bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
-):
+    bind_addresses: Collection[str],
+    port: int,
+    factory: ServerFactory,
+    context_factory: IOpenSSLContextFactory,
+    reactor: IReactorSSL = reactor,
+    backlog: int = 50,
+) -> List[Port]:
     """
     Create an TLS-over-TCP socket for a port and several addresses
 
@@ -294,10 +333,13 @@ def listen_ssl(
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)
 
-    return r
+    # IReactorSSL incorrectly declares that an int is returned from listenSSL,
+    # it actually returns an object implementing IListeningPort, but we know it
+    # will be a Port instance.
+    return r  # type: ignore[return-value]
 
 
-def refresh_certificate(hs: "HomeServer"):
+def refresh_certificate(hs: "HomeServer") -> None:
     """
     Refresh the TLS certificates that Synapse is using by re-reading them from
     disk and updating the TLS context factories to use them.
@@ -329,7 +371,7 @@ def refresh_certificate(hs: "HomeServer"):
         logger.info("Context factories updated.")
 
 
-async def start(hs: "HomeServer"):
+async def start(hs: "HomeServer") -> None:
     """
     Start a Synapse server or worker.
 
@@ -360,7 +402,7 @@ async def start(hs: "HomeServer"):
     if hasattr(signal, "SIGHUP"):
 
         @wrap_as_background_process("sighup")
-        def handle_sighup(*args, **kwargs):
+        def handle_sighup(*args: Any, **kwargs: Any) -> None:
             # Tell systemd our state, if we're using it. This will silently fail if
             # we're not using systemd.
             sdnotify(b"RELOADING=1")
@@ -373,7 +415,7 @@ async def start(hs: "HomeServer"):
         # We defer running the sighup handlers until next reactor tick. This
         # is so that we're in a sane state, e.g. flushing the logs may fail
         # if the sighup happens in the middle of writing a log entry.
-        def run_sighup(*args, **kwargs):
+        def run_sighup(*args: Any, **kwargs: Any) -> None:
             # `callFromThread` should be "signal safe" as well as thread
             # safe.
             reactor.callFromThread(handle_sighup, *args, **kwargs)
@@ -436,12 +478,8 @@ async def start(hs: "HomeServer"):
         atexit.register(gc.freeze)
 
 
-def setup_sentry(hs: "HomeServer"):
-    """Enable sentry integration, if enabled in configuration
-
-    Args:
-        hs
-    """
+def setup_sentry(hs: "HomeServer") -> None:
+    """Enable sentry integration, if enabled in configuration"""
 
     if not hs.config.metrics.sentry_enabled:
         return
@@ -466,7 +504,7 @@ def setup_sentry(hs: "HomeServer"):
         scope.set_tag("worker_name", name)
 
 
-def setup_sdnotify(hs: "HomeServer"):
+def setup_sdnotify(hs: "HomeServer") -> None:
     """Adds process state hooks to tell systemd what we are up to."""
 
     # Tell systemd our state, if we're using it. This will silently fail if
@@ -481,7 +519,7 @@ def setup_sdnotify(hs: "HomeServer"):
 sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
 
 
-def sdnotify(state):
+def sdnotify(state: bytes) -> None:
     """
     Send a notification to systemd, if the NOTIFY_SOCKET env var is set.
 
@@ -490,7 +528,7 @@ def sdnotify(state):
     package which many OSes don't include as a matter of principle.
 
     Args:
-        state (bytes): notification to send
+        state: notification to send
     """
     if not isinstance(state, bytes):
         raise TypeError("sdnotify should be called with a bytes")