summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/auth.py16
-rw-r--r--synapse/push/pusherpool.py24
-rw-r--r--synapse/storage/registration.py10
3 files changed, 32 insertions, 18 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 080eb14271..0ba66bc947 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -664,9 +664,6 @@ class AuthHandler(BaseHandler):
         yield self.delete_access_tokens_for_user(
             user_id, except_token_id=except_access_token_id,
         )
-        yield self.hs.get_pusherpool().remove_pushers_by_user(
-            user_id, except_access_token_id
-        )
 
     @defer.inlineCallbacks
     def deactivate_account(self, user_id):
@@ -706,6 +703,12 @@ class AuthHandler(BaseHandler):
                     access_token=access_token,
                 )
 
+        # delete pushers associated with this access token
+        if user_info["token_id"] is not None:
+            yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+                str(user_info["user"]), (user_info["token_id"], )
+            )
+
     @defer.inlineCallbacks
     def delete_access_tokens_for_user(self, user_id, except_token_id=None,
                                       device_id=None):
@@ -728,13 +731,18 @@ class AuthHandler(BaseHandler):
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
             if hasattr(provider, "on_logged_out"):
-                for token, device_id in tokens_and_devices:
+                for token, token_id, device_id in tokens_and_devices:
                     yield provider.on_logged_out(
                         user_id=user_id,
                         device_id=device_id,
                         access_token=token,
                     )
 
+        # delete pushers associated with the access tokens
+        yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+            user_id, (token_id for _, token_id, _ in tokens_and_devices),
+        )
+
     @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
         # 'Canonicalise' email addresses down to lower case.
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 34cb108dcb..134e89b371 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -103,19 +103,25 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id, except_access_token_id=None):
-        all = yield self.store.get_all_pushers()
-        logger.info(
-            "Removing all pushers for user %s except access tokens id %r",
-            user_id, except_access_token_id
-        )
-        for p in all:
-            if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+    def remove_pushers_by_access_token(self, user_id, access_tokens):
+        """Remove the pushers for a given user corresponding to a set of
+        access_tokens.
+
+        Args:
+            user_id (str): user to remove pushers for
+            access_tokens (Iterable[int]): access token *ids* to remove pushers
+                for
+        """
+        tokens = set(access_tokens)
+        for p in (yield self.store.get_pushers_by_user_id(user_id)):
+            if p['access_token'] in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
                 )
-                yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+                yield self.remove_pusher(
+                    p['app_id'], p['pushkey'], p['user_name'],
+                )
 
     @defer.inlineCallbacks
     def on_new_notifications(self, min_stream_id, max_stream_id):
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 8b9544c209..3aa810981f 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 If None, tokens associated with any device (or no device) will
                 be deleted
         Returns:
-            defer.Deferred[list[str, str|None]]: a list of the deleted tokens
-                and device IDs
+            defer.Deferred[list[str, int, str|None, int]]: a list of
+                (token, token id, device id) for each of the deleted tokens
         """
         def f(txn):
             keyvalues = {
@@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values.append(except_token_id)
 
             txn.execute(
-                "SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
+                "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
                 values
             )
-            tokens_and_devices = [(r[0], r[1]) for r in txn]
+            tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
-            for token, _ in tokens_and_devices:
+            for token, _, _ in tokens_and_devices:
                 self._invalidate_cache_and_stream(
                     txn, self.get_user_by_access_token, (token,)
                 )