summary refs log tree commit diff
path: root/synapse/app/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/_base.py')
-rw-r--r--synapse/app/_base.py107
1 files changed, 42 insertions, 65 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 08199a5e8d..d50a9840d4 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -19,7 +19,6 @@ import signal
 import sys
 import traceback
 
-import psutil
 from daemonize import Daemonize
 
 from twisted.internet import defer, error, reactor
@@ -68,21 +67,13 @@ def start_worker_reactor(appname, config):
         gc_thresholds=config.gc_thresholds,
         pid_file=config.worker_pid_file,
         daemonize=config.worker_daemonize,
-        cpu_affinity=config.worker_cpu_affinity,
         print_pidfile=config.print_pidfile,
         logger=logger,
     )
 
 
 def start_reactor(
-        appname,
-        soft_file_limit,
-        gc_thresholds,
-        pid_file,
-        daemonize,
-        cpu_affinity,
-        print_pidfile,
-        logger,
+    appname, soft_file_limit, gc_thresholds, pid_file, daemonize, print_pidfile, logger
 ):
     """ Run the reactor in the main process
 
@@ -95,7 +86,6 @@ def start_reactor(
         gc_thresholds:
         pid_file (str): name of pid file to write to if daemonize is True
         daemonize (bool): true to run the reactor in a background process
-        cpu_affinity (int|None): cpu affinity mask
         print_pidfile (bool): whether to print the pid file, if daemonize is True
         logger (logging.Logger): logger instance to pass to Daemonize
     """
@@ -109,20 +99,6 @@ def start_reactor(
         # between the sentinel and `run` logcontexts.
         with PreserveLoggingContext():
             logger.info("Running")
-            if cpu_affinity is not None:
-                # Turn the bitmask into bits, reverse it so we go from 0 up
-                mask_to_bits = bin(cpu_affinity)[2:][::-1]
-
-                cpus = []
-                cpu_num = 0
-
-                for i in mask_to_bits:
-                    if i == "1":
-                        cpus.append(cpu_num)
-                    cpu_num += 1
-
-                p = psutil.Process()
-                p.cpu_affinity(cpus)
 
             change_resource_limit(soft_file_limit)
             if gc_thresholds:
@@ -149,10 +125,10 @@ def start_reactor(
 def quit_with_error(error_string):
     message_lines = error_string.split("\n")
     line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
-    sys.stderr.write("*" * line_length + '\n')
+    sys.stderr.write("*" * line_length + "\n")
     for line in message_lines:
         sys.stderr.write(" %s\n" % (line.rstrip(),))
-    sys.stderr.write("*" * line_length + '\n')
+    sys.stderr.write("*" * line_length + "\n")
     sys.exit(1)
 
 
@@ -178,14 +154,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
     r = []
     for address in bind_addresses:
         try:
-            r.append(
-                reactor.listenTCP(
-                    port,
-                    factory,
-                    backlog,
-                    address
-                )
-            )
+            r.append(reactor.listenTCP(port, factory, backlog, address))
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)
 
@@ -205,13 +174,7 @@ def listen_ssl(
     for address in bind_addresses:
         try:
             r.append(
-                reactor.listenSSL(
-                    port,
-                    factory,
-                    context_factory,
-                    backlog,
-                    address
-                )
+                reactor.listenSSL(port, factory, context_factory, backlog, address)
             )
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)
@@ -243,15 +206,13 @@ def refresh_certificate(hs):
             if isinstance(i.factory, TLSMemoryBIOFactory):
                 addr = i.getHost()
                 logger.info(
-                    "Replacing TLS context factory on [%s]:%i", addr.host, addr.port,
+                    "Replacing TLS context factory on [%s]:%i", addr.host, addr.port
                 )
                 # We want to replace TLS factories with a new one, with the new
                 # TLS configuration. We do this by reaching in and pulling out
                 # the wrappedFactory, and then re-wrapping it.
                 i.factory = TLSMemoryBIOFactory(
-                    hs.tls_server_context_factory,
-                    False,
-                    i.factory.wrappedFactory
+                    hs.tls_server_context_factory, False, i.factory.wrappedFactory
                 )
         logger.info("Context factories updated.")
 
@@ -267,6 +228,7 @@ def start(hs, listeners=None):
     try:
         # Set up the SIGHUP machinery.
         if hasattr(signal, "SIGHUP"):
+
             def handle_sighup(*args, **kwargs):
                 for i in _sighup_callbacks:
                     i(hs)
@@ -302,10 +264,8 @@ def setup_sentry(hs):
         return
 
     import sentry_sdk
-    sentry_sdk.init(
-        dsn=hs.config.sentry_dsn,
-        release=get_version_string(synapse),
-    )
+
+    sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse))
 
     # We set some default tags that give some context to this instance
     with sentry_sdk.configure_scope() as scope:
@@ -326,7 +286,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
     many DNS queries at once
     """
     new_resolver = _LimitedHostnameResolver(
-        reactor.nameResolver, max_dns_requests_in_flight,
+        reactor.nameResolver, max_dns_requests_in_flight
     )
 
     reactor.installNameResolver(new_resolver)
@@ -339,26 +299,44 @@ class _LimitedHostnameResolver(object):
     def __init__(self, resolver, max_dns_requests_in_flight):
         self._resolver = resolver
         self._limiter = Linearizer(
-            name="dns_client_limiter", max_count=max_dns_requests_in_flight,
+            name="dns_client_limiter", max_count=max_dns_requests_in_flight
         )
 
-    def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
-                        addressTypes=None, transportSemantics='TCP'):
-        # Note this is happening deep within the reactor, so we don't need to
-        # worry about log contexts.
-
+    def resolveHostName(
+        self,
+        resolutionReceiver,
+        hostName,
+        portNumber=0,
+        addressTypes=None,
+        transportSemantics="TCP",
+    ):
         # We need this function to return `resolutionReceiver` so we do all the
         # actual logic involving deferreds in a separate function.
-        self._resolve(
-            resolutionReceiver, hostName, portNumber,
-            addressTypes, transportSemantics,
-        )
+
+        # even though this is happening within the depths of twisted, we need to drop
+        # our logcontext before starting _resolve, otherwise: (a) _resolve will drop
+        # the logcontext if it returns an incomplete deferred; (b) _resolve will
+        # call the resolutionReceiver *with* a logcontext, which it won't be expecting.
+        with PreserveLoggingContext():
+            self._resolve(
+                resolutionReceiver,
+                hostName,
+                portNumber,
+                addressTypes,
+                transportSemantics,
+            )
 
         return resolutionReceiver
 
     @defer.inlineCallbacks
-    def _resolve(self, resolutionReceiver, hostName, portNumber=0,
-                 addressTypes=None, transportSemantics='TCP'):
+    def _resolve(
+        self,
+        resolutionReceiver,
+        hostName,
+        portNumber=0,
+        addressTypes=None,
+        transportSemantics="TCP",
+    ):
 
         with (yield self._limiter.queue(())):
             # resolveHostName doesn't return a Deferred, so we need to hook into
@@ -368,8 +346,7 @@ class _LimitedHostnameResolver(object):
             receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)
 
             self._resolver.resolveHostName(
-                receiver, hostName, portNumber,
-                addressTypes, transportSemantics,
+                receiver, hostName, portNumber, addressTypes, transportSemantics
             )
 
             yield deferred