summary refs log tree commit diff
path: root/synapse/push/pusherpool.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push/pusherpool.py')
-rw-r--r--synapse/push/pusherpool.py81
1 files changed, 61 insertions, 20 deletions
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 1e0ef44fc7..2597898cf4 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -94,7 +94,7 @@ class PusherPool:
             return
         run_as_background_process("start_pushers", self._start_pushers)
 
-    async def add_pusher(
+    async def add_or_update_pusher(
         self,
         user_id: str,
         access_token: Optional[int],
@@ -106,6 +106,7 @@ class PusherPool:
         lang: Optional[str],
         data: JsonDict,
         profile_tag: str = "",
+        enabled: bool = True,
     ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
@@ -147,9 +148,20 @@ class PusherPool:
                 last_stream_ordering=last_stream_ordering,
                 last_success=None,
                 failing_since=None,
+                enabled=enabled,
             )
         )
 
+        # Before we actually persist the pusher, we check if the user already has one
+        # for this app ID and pushkey. If so, we want to keep the access token in place,
+        # since this could be one device modifying (e.g. enabling/disabling) another
+        # device's pusher.
+        existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+            user_id, app_id, pushkey
+        )
+        if existing_config:
+            access_token = existing_config.access_token
+
         await self.store.add_pusher(
             user_id=user_id,
             access_token=access_token,
@@ -163,8 +175,9 @@ class PusherPool:
             data=data,
             last_stream_ordering=last_stream_ordering,
             profile_tag=profile_tag,
+            enabled=enabled,
         )
-        pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
+        pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
 
         return pusher
 
@@ -276,10 +289,25 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
-    async def start_pusher_by_id(
+    async def _get_pusher_config_for_user_by_app_id_and_pushkey(
+        self, user_id: str, app_id: str, pushkey: str
+    ) -> Optional[PusherConfig]:
+        resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+
+        pusher_config = None
+        for r in resultlist:
+            if r.user_name == user_id:
+                pusher_config = r
+
+        return pusher_config
+
+    async def process_pusher_change_by_id(
         self, app_id: str, pushkey: str, user_id: str
     ) -> Optional[Pusher]:
-        """Look up the details for the given pusher, and start it
+        """Look up the details for the given pusher, and either start it if its
+        "enabled" flag is True, or try to stop it otherwise.
+
+        If the pusher is new and its "enabled" flag is False, the stop is a noop.
 
         Returns:
             The pusher started, if any
@@ -290,12 +318,13 @@ class PusherPool:
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
             return None
 
-        resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+        pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+            user_id, app_id, pushkey
+        )
 
-        pusher_config = None
-        for r in resultlist:
-            if r.user_name == user_id:
-                pusher_config = r
+        if pusher_config and not pusher_config.enabled:
+            self.maybe_stop_pusher(app_id, pushkey, user_id)
+            return None
 
         pusher = None
         if pusher_config:
@@ -305,7 +334,7 @@ class PusherPool:
 
     async def _start_pushers(self) -> None:
         """Start all the pushers"""
-        pushers = await self.store.get_all_pushers()
+        pushers = await self.store.get_enabled_pushers()
 
         # Stagger starting up the pushers so we don't completely drown the
         # process on start up.
@@ -363,6 +392,8 @@ class PusherPool:
 
         synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
 
+        logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
+
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
@@ -382,16 +413,7 @@ class PusherPool:
         return pusher
 
     async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
-        appid_pushkey = "%s:%s" % (app_id, pushkey)
-
-        byuser = self.pushers.get(user_id, {})
-
-        if appid_pushkey in byuser:
-            logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
-            pusher = byuser.pop(appid_pushkey)
-            pusher.on_stop()
-
-            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
+        self.maybe_stop_pusher(app_id, pushkey, user_id)
 
         # We can only delete pushers on master.
         if self._remove_pusher_client:
@@ -402,3 +424,22 @@ class PusherPool:
             await self.store.delete_pusher_by_app_id_pushkey_user_id(
                 app_id, pushkey, user_id
             )
+
+    def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
+        """Stops a pusher with the given app ID and push key if one is running.
+
+        Args:
+            app_id: the pusher's app ID.
+            pushkey: the pusher's push key.
+            user_id: the user the pusher belongs to. Only used for logging.
+        """
+        appid_pushkey = "%s:%s" % (app_id, pushkey)
+
+        byuser = self.pushers.get(user_id, {})
+
+        if appid_pushkey in byuser:
+            logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
+            pusher = byuser.pop(appid_pushkey)
+            pusher.on_stop()
+
+            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()