summary refs log tree commit diff
path: root/synapse/push/pusherpool.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push/pusherpool.py')
-rw-r--r--synapse/push/pusherpool.py78
1 files changed, 42 insertions, 36 deletions
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f6a5458681..2456f12f46 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,13 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import defaultdict
-from threading import Lock
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Union
+
+from prometheus_client import Gauge
 
 from twisted.internet import defer
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import PusherConfigException
 from synapse.push.emailpusher import EmailPusher
@@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
 from synapse.push.pusher import PusherFactory
 from synapse.util.async_helpers import concurrently_execute
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
+synapse_pushers = Gauge(
+    "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
+)
+
+
 class PusherPool:
     """
     The pusher pool. This is responsible for dispatching notifications of new events to
@@ -47,36 +55,20 @@ class PusherPool:
     Pusher.on_new_receipts are not expected to return deferreds.
     """
 
-    def __init__(self, _hs):
-        self.hs = _hs
-        self.pusher_factory = PusherFactory(_hs)
-        self._should_start_pushers = _hs.config.start_pushers
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.pusher_factory = PusherFactory(hs)
+        self._should_start_pushers = hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
+        # We shard the handling of push notifications by user ID.
+        self._pusher_shard_config = hs.config.push.pusher_shard_config
+        self._instance_name = hs.get_instance_name()
+
         # map from user id to app_id:pushkey to pusher
         self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
 
-        # a lock for the pushers dict, since `count_pushers` is called from an different
-        # and we otherwise get concurrent modification errors
-        self._pushers_lock = Lock()
-
-        def count_pushers():
-            results = defaultdict(int)  # type: Dict[Tuple[str, str], int]
-            with self._pushers_lock:
-                for pushers in self.pushers.values():
-                    for pusher in pushers.values():
-                        k = (type(pusher).__name__, pusher.app_id)
-                        results[k] += 1
-            return results
-
-        LaterGauge(
-            name="synapse_pushers",
-            desc="the number of active pushers",
-            labels=["kind", "app_id"],
-            caller=count_pushers,
-        )
-
     def start(self):
         """Starts the pushers off in a background process.
         """
@@ -104,6 +96,7 @@ class PusherPool:
         Returns:
             Deferred[EmailPusher|HttpPusher]
         """
+
         time_now_msec = self.clock.time_msec()
 
         # we try to create the pusher just to validate the config: it
@@ -176,6 +169,9 @@ class PusherPool:
             access_tokens (Iterable[int]): access token *ids* to remove pushers
                 for
         """
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
         tokens = set(access_tokens)
         for p in (yield self.store.get_pushers_by_user_id(user_id)):
             if p["access_token"] in tokens:
@@ -237,6 +233,9 @@ class PusherPool:
         if not self._should_start_pushers:
             return
 
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
         resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
         pusher_dict = None
@@ -275,6 +274,11 @@ class PusherPool:
         Returns:
             Deferred[EmailPusher|HttpPusher]
         """
+        if not self._pusher_shard_config.should_handle(
+            self._instance_name, pusherdict["user_name"]
+        ):
+            return
+
         try:
             p = self.pusher_factory.create_pusher(pusherdict)
         except PusherConfigException as e:
@@ -298,11 +302,12 @@ class PusherPool:
 
         appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
 
-        with self._pushers_lock:
-            byuser = self.pushers.setdefault(pusherdict["user_name"], {})
-            if appid_pushkey in byuser:
-                byuser[appid_pushkey].on_stop()
-            byuser[appid_pushkey] = p
+        byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+        if appid_pushkey in byuser:
+            byuser[appid_pushkey].on_stop()
+        byuser[appid_pushkey] = p
+
+        synapse_pushers.labels(type(p).__name__, p.app_id).inc()
 
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
@@ -330,9 +335,10 @@ class PusherPool:
 
         if appid_pushkey in byuser:
             logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
-            byuser[appid_pushkey].on_stop()
-            with self._pushers_lock:
-                del byuser[appid_pushkey]
+            pusher = byuser.pop(appid_pushkey)
+            pusher.on_stop()
+
+            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
 
         yield self.store.delete_pusher_by_app_id_pushkey_user_id(
             app_id, pushkey, user_id