summary refs log tree commit diff
path: root/synapse/push/pusherpool.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push/pusherpool.py')
-rw-r--r--synapse/push/pusherpool.py61
1 files changed, 36 insertions, 25 deletions
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 43cb6e9c01..36bb5bbc65 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -14,13 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 from twisted.internet import defer
 
-from .pusher import PusherFactory
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.util.async import run_on_reactor
-
-import logging
+from synapse.push.pusher import PusherFactory
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
 logger = logging.getLogger(__name__)
 
@@ -103,23 +102,28 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id, except_access_token_id=None):
-        all = yield self.store.get_all_pushers()
-        logger.info(
-            "Removing all pushers for user %s except access tokens id %r",
-            user_id, except_access_token_id
-        )
-        for p in all:
-            if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+    def remove_pushers_by_access_token(self, user_id, access_tokens):
+        """Remove the pushers for a given user corresponding to a set of
+        access_tokens.
+
+        Args:
+            user_id (str): user to remove pushers for
+            access_tokens (Iterable[int]): access token *ids* to remove pushers
+                for
+        """
+        tokens = set(access_tokens)
+        for p in (yield self.store.get_pushers_by_user_id(user_id)):
+            if p['access_token'] in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
                 )
-                yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+                yield self.remove_pusher(
+                    p['app_id'], p['pushkey'], p['user_name'],
+                )
 
     @defer.inlineCallbacks
     def on_new_notifications(self, min_stream_id, max_stream_id):
-        yield run_on_reactor()
         try:
             users_affected = yield self.store.get_push_action_users_in_range(
                 min_stream_id, max_stream_id
@@ -131,18 +135,20 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            preserve_fn(p.on_new_notifications)(
-                                min_stream_id, max_stream_id
+                            run_in_background(
+                                p.on_new_notifications,
+                                min_stream_id, max_stream_id,
                             )
                         )
 
-            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
-        except:
+            yield make_deferred_yieldable(
+                defer.gatherResults(deferreds, consumeErrors=True),
+            )
+        except Exception:
             logger.exception("Exception in pusher on_new_notifications")
 
     @defer.inlineCallbacks
     def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
-        yield run_on_reactor()
         try:
             # Need to subtract 1 from the minimum because the lower bound here
             # is not inclusive
@@ -158,11 +164,16 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+                            run_in_background(
+                                p.on_new_receipts,
+                                min_stream_id, max_stream_id,
+                            )
                         )
 
-            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
-        except:
+            yield make_deferred_yieldable(
+                defer.gatherResults(deferreds, consumeErrors=True),
+            )
+        except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
     @defer.inlineCallbacks
@@ -188,7 +199,7 @@ class PusherPool:
         for pusherdict in pushers:
             try:
                 p = self.pusher_factory.create_pusher(pusherdict)
-            except:
+            except Exception:
                 logger.exception("Couldn't start a pusher: caught Exception")
                 continue
             if p:
@@ -201,7 +212,7 @@ class PusherPool:
                 if appid_pushkey in byuser:
                     byuser[appid_pushkey].on_stop()
                 byuser[appid_pushkey] = p
-                preserve_fn(p.on_started)()
+                run_in_background(p.on_started)
 
         logger.info("Started pushers")