diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 9a92b35361..7e7d32eb66 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,18 +18,31 @@ import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
-
-from ._base import SQLBaseStore
+from synapse.storage import background_updates
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-class RegistrationStore(SQLBaseStore):
+class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
+ self.register_background_index_update(
+ "access_tokens_device_index",
+ index_name="access_tokens_device_id",
+ table="access_tokens",
+ columns=["user_id", "device_id"],
+ )
+
+ self.register_background_index_update(
+ "refresh_tokens_device_index",
+ index_name="refresh_tokens_device_id",
+ table="refresh_tokens",
+ columns=["user_id", "device_id"],
+ )
+
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -238,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks
- def user_delete_access_tokens(self, user_id, except_token_ids=[]):
- def f(txn):
- sql = "SELECT token FROM access_tokens WHERE user_id = ?"
+ def user_delete_access_tokens(self, user_id, except_token_ids=[],
+ device_id=None,
+ delete_refresh_tokens=False):
+ """
+ Invalidate access/refresh tokens belonging to a user
+
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_ids (list[str]): list of access_tokens which should
+ *not* be deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ delete_refresh_tokens (bool): True to delete refresh tokens as
+ well as access tokens.
+ Returns:
+ defer.Deferred:
+ """
+ def f(txn, table, except_tokens, call_after_delete):
+ sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
- if except_token_ids:
+ if device_id is not None:
+ sql += " AND device_id = ?"
+ clauses.append(device_id)
+
+ if except_tokens:
sql += " AND id NOT IN (%s)" % (
- ",".join(["?" for _ in except_token_ids]),
+ ",".join(["?" for _ in except_tokens]),
)
- clauses += except_token_ids
+ clauses += except_tokens
txn.execute(sql, clauses)
@@ -256,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
- for row in chunk:
- txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+ if call_after_delete:
+ for row in chunk:
+ txn.call_after(call_after_delete, (row[0],))
txn.execute(
- "DELETE FROM access_tokens WHERE token in (%s)" % (
+ "DELETE FROM %s WHERE token in (%s)" % (
+ table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
- yield self.runInteraction("user_delete_access_tokens", f)
+ # delete refresh tokens first, to stop new access tokens being
+ # allocated while our backs are turned
+ if delete_refresh_tokens:
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="refresh_tokens",
+ except_tokens=[],
+ call_after_delete=None,
+ )
+
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ table="access_tokens",
+ except_tokens=except_token_ids,
+ call_after_delete=self.get_user_by_access_token.invalidate,
+ )
def delete_access_token(self, access_token):
def f(txn):
@@ -288,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
Args:
token (str): The access token of a user.
Returns:
- dict: Including the name (user_id) and the ID of their access token.
- Raises:
- StoreError if no user was found.
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
|