summary refs log tree commit diff
path: root/synapse/app/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/_base.py')
-rw-r--r--synapse/app/_base.py36
1 files changed, 19 insertions, 17 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 391bd14c5c..18584226e9 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -17,6 +17,7 @@ import gc
 import logging
 import sys
 
+import psutil
 from daemonize import Daemonize
 
 from twisted.internet import error, reactor
@@ -24,12 +25,6 @@ from twisted.internet import error, reactor
 from synapse.util import PreserveLoggingContext
 from synapse.util.rlimit import change_resource_limit
 
-try:
-    import affinity
-except Exception:
-    affinity = None
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -89,15 +84,20 @@ def start_reactor(
         with PreserveLoggingContext():
             logger.info("Running")
             if cpu_affinity is not None:
-                if not affinity:
-                    quit_with_error(
-                        "Missing package 'affinity' required for cpu_affinity\n"
-                        "option\n\n"
-                        "Install by running:\n\n"
-                        "   pip install affinity\n\n"
-                    )
-                logger.info("Setting CPU affinity to %s" % cpu_affinity)
-                affinity.set_process_affinity_mask(0, cpu_affinity)
+                # Turn the bitmask into bits, reverse it so we go from 0 up
+                mask_to_bits = bin(cpu_affinity)[2:][::-1]
+
+                cpus = []
+                cpu_num = 0
+
+                for i in mask_to_bits:
+                    if i == "1":
+                        cpus.append(cpu_num)
+                    cpu_num += 1
+
+                p = psutil.Process()
+                p.cpu_affinity(cpus)
+
             change_resource_limit(soft_file_limit)
             if gc_thresholds:
                 gc.set_threshold(*gc_thresholds)
@@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port):
         logger.info("Metrics now reporting on %s:%d", host, port)
 
 
-def listen_tcp(bind_addresses, port, factory, backlog=50):
+def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
     """
     Create a TCP socket for a port and several addresses
     """
@@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50):
             check_bind_error(e, address, bind_addresses)
 
 
-def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+def listen_ssl(
+    bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
+):
     """
     Create an SSL socket for a port and several addresses
     """