summary refs log tree commit diff
path: root/synapse/push/pusherpool.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push/pusherpool.py')
-rw-r--r--synapse/push/pusherpool.py148
1 files changed, 69 insertions, 79 deletions
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 88d203aa44..3c3262a88c 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,13 +15,10 @@
 # limitations under the License.
 
 import logging
-from collections import defaultdict
-from threading import Lock
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Union
 
-from twisted.internet import defer
+from prometheus_client import Gauge
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import PusherConfigException
 from synapse.push.emailpusher import EmailPusher
@@ -29,9 +26,18 @@ from synapse.push.httppusher import HttpPusher
 from synapse.push.pusher import PusherFactory
 from synapse.util.async_helpers import concurrently_execute
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
+synapse_pushers = Gauge(
+    "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
+)
+
+
 class PusherPool:
     """
     The pusher pool. This is responsible for dispatching notifications of new events to
@@ -44,39 +50,23 @@ class PusherPool:
     Note that it is expected that each pusher will have its own 'processing' loop which
     will send out the notifications in the background, rather than blocking until the
     notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
-    Pusher.on_new_receipts are not expected to return deferreds.
+    Pusher.on_new_receipts are not expected to return awaitables.
     """
 
-    def __init__(self, _hs):
-        self.hs = _hs
-        self.pusher_factory = PusherFactory(_hs)
-        self._should_start_pushers = _hs.config.start_pushers
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.pusher_factory = PusherFactory(hs)
+        self._should_start_pushers = hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
+        # We shard the handling of push notifications by user ID.
+        self._pusher_shard_config = hs.config.push.pusher_shard_config
+        self._instance_name = hs.get_instance_name()
+
         # map from user id to app_id:pushkey to pusher
         self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
 
-        # a lock for the pushers dict, since `count_pushers` is called from an different
-        # and we otherwise get concurrent modification errors
-        self._pushers_lock = Lock()
-
-        def count_pushers():
-            results = defaultdict(int)  # type: Dict[Tuple[str, str], int]
-            with self._pushers_lock:
-                for pushers in self.pushers.values():
-                    for pusher in pushers.values():
-                        k = (type(pusher).__name__, pusher.app_id)
-                        results[k] += 1
-            return results
-
-        LaterGauge(
-            name="synapse_pushers",
-            desc="the number of active pushers",
-            labels=["kind", "app_id"],
-            caller=count_pushers,
-        )
-
     def start(self):
         """Starts the pushers off in a background process.
         """
@@ -85,8 +75,7 @@ class PusherPool:
             return
         run_as_background_process("start_pushers", self._start_pushers)
 
-    @defer.inlineCallbacks
-    def add_pusher(
+    async def add_pusher(
         self,
         user_id,
         access_token,
@@ -102,8 +91,9 @@ class PusherPool:
         """Creates a new pusher and adds it to the pool
 
         Returns:
-            Deferred[EmailPusher|HttpPusher]
+            EmailPusher|HttpPusher
         """
+
         time_now_msec = self.clock.time_msec()
 
         # we try to create the pusher just to validate the config: it
@@ -131,9 +121,9 @@ class PusherPool:
         # create the pusher setting last_stream_ordering to the current maximum
         # stream ordering in event_push_actions, so it will process
         # pushes from this point onwards.
-        last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
+        last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
 
-        yield self.store.add_pusher(
+        await self.store.add_pusher(
             user_id=user_id,
             access_token=access_token,
             kind=kind,
@@ -147,15 +137,14 @@ class PusherPool:
             last_stream_ordering=last_stream_ordering,
             profile_tag=profile_tag,
         )
-        pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
+        pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
 
         return pusher
 
-    @defer.inlineCallbacks
-    def remove_pushers_by_app_id_and_pushkey_not_user(
+    async def remove_pushers_by_app_id_and_pushkey_not_user(
         self, app_id, pushkey, not_user_id
     ):
-        to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+        to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
         for p in to_remove:
             if p["user_name"] != not_user_id:
                 logger.info(
@@ -164,10 +153,9 @@ class PusherPool:
                     pushkey,
                     p["user_name"],
                 )
-                yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
 
-    @defer.inlineCallbacks
-    def remove_pushers_by_access_token(self, user_id, access_tokens):
+    async def remove_pushers_by_access_token(self, user_id, access_tokens):
         """Remove the pushers for a given user corresponding to a set of
         access_tokens.
 
@@ -176,8 +164,11 @@ class PusherPool:
             access_tokens (Iterable[int]): access token *ids* to remove pushers
                 for
         """
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
         tokens = set(access_tokens)
-        for p in (yield self.store.get_pushers_by_user_id(user_id)):
+        for p in await self.store.get_pushers_by_user_id(user_id):
             if p["access_token"] in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
@@ -185,16 +176,15 @@ class PusherPool:
                     p["pushkey"],
                     p["user_name"],
                 )
-                yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
 
-    @defer.inlineCallbacks
-    def on_new_notifications(self, min_stream_id, max_stream_id):
+    async def on_new_notifications(self, min_stream_id, max_stream_id):
         if not self.pushers:
             # nothing to do here.
             return
 
         try:
-            users_affected = yield self.store.get_push_action_users_in_range(
+            users_affected = await self.store.get_push_action_users_in_range(
                 min_stream_id, max_stream_id
             )
 
@@ -206,8 +196,7 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
 
-    @defer.inlineCallbacks
-    def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+    async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
         if not self.pushers:
             # nothing to do here.
             return
@@ -215,11 +204,9 @@ class PusherPool:
         try:
             # Need to subtract 1 from the minimum because the lower bound here
             # is not inclusive
-            updated_receipts = yield self.store.get_all_updated_receipts(
+            users_affected = await self.store.get_users_sent_receipts_between(
                 min_stream_id - 1, max_stream_id
             )
-            # This returns a tuple, user_id is at index 3
-            users_affected = {r[3] for r in updated_receipts}
 
             for u in users_affected:
                 if u in self.pushers:
@@ -229,17 +216,19 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
-    @defer.inlineCallbacks
-    def start_pusher_by_id(self, app_id, pushkey, user_id):
+    async def start_pusher_by_id(self, app_id, pushkey, user_id):
         """Look up the details for the given pusher, and start it
 
         Returns:
-            Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any
+            EmailPusher|HttpPusher|None: The pusher started, if any
         """
         if not self._should_start_pushers:
             return
 
-        resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+        if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
+            return
+
+        resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
         pusher_dict = None
         for r in resultlist:
@@ -248,35 +237,35 @@ class PusherPool:
 
         pusher = None
         if pusher_dict:
-            pusher = yield self._start_pusher(pusher_dict)
+            pusher = await self._start_pusher(pusher_dict)
 
         return pusher
 
-    @defer.inlineCallbacks
-    def _start_pushers(self):
+    async def _start_pushers(self) -> None:
         """Start all the pushers
-
-        Returns:
-            Deferred
         """
-        pushers = yield self.store.get_all_pushers()
+        pushers = await self.store.get_all_pushers()
 
         # Stagger starting up the pushers so we don't completely drown the
         # process on start up.
-        yield concurrently_execute(self._start_pusher, pushers, 10)
+        await concurrently_execute(self._start_pusher, pushers, 10)
 
         logger.info("Started pushers")
 
-    @defer.inlineCallbacks
-    def _start_pusher(self, pusherdict):
+    async def _start_pusher(self, pusherdict):
         """Start the given pusher
 
         Args:
             pusherdict (dict): dict with the values pulled from the db table
 
         Returns:
-            Deferred[EmailPusher|HttpPusher]
+            EmailPusher|HttpPusher
         """
+        if not self._pusher_shard_config.should_handle(
+            self._instance_name, pusherdict["user_name"]
+        ):
+            return
+
         try:
             p = self.pusher_factory.create_pusher(pusherdict)
         except PusherConfigException as e:
@@ -300,11 +289,12 @@ class PusherPool:
 
         appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
 
-        with self._pushers_lock:
-            byuser = self.pushers.setdefault(pusherdict["user_name"], {})
-            if appid_pushkey in byuser:
-                byuser[appid_pushkey].on_stop()
-            byuser[appid_pushkey] = p
+        byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+        if appid_pushkey in byuser:
+            byuser[appid_pushkey].on_stop()
+        byuser[appid_pushkey] = p
+
+        synapse_pushers.labels(type(p).__name__, p.app_id).inc()
 
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
@@ -312,7 +302,7 @@ class PusherPool:
         user_id = pusherdict["user_name"]
         last_stream_ordering = pusherdict["last_stream_ordering"]
         if last_stream_ordering:
-            have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
+            have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
                 user_id, last_stream_ordering
             )
         else:
@@ -324,18 +314,18 @@ class PusherPool:
 
         return p
 
-    @defer.inlineCallbacks
-    def remove_pusher(self, app_id, pushkey, user_id):
+    async def remove_pusher(self, app_id, pushkey, user_id):
         appid_pushkey = "%s:%s" % (app_id, pushkey)
 
         byuser = self.pushers.get(user_id, {})
 
         if appid_pushkey in byuser:
             logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
-            byuser[appid_pushkey].on_stop()
-            with self._pushers_lock:
-                del byuser[appid_pushkey]
+            pusher = byuser.pop(appid_pushkey)
+            pusher.on_stop()
+
+            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
 
-        yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+        await self.store.delete_pusher_by_app_id_pushkey_user_id(
             app_id, pushkey, user_id
         )