summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py111
1 files changed, 32 insertions, 79 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 983ce13291..8b2c2a97ab 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, ThreepidValidationError
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
 from synapse.types import UserID
@@ -89,7 +90,8 @@ class RegistrationWorkerStore(SQLBaseStore):
             token (str): The access token of a user.
         Returns:
             defer.Deferred: None, if the token did not match, otherwise dict
-                including the keys `name`, `is_guest`, `device_id`, `token_id`.
+                including the keys `name`, `is_guest`, `device_id`, `token_id`,
+                `valid_until_ms`.
         """
         return self.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
@@ -283,7 +285,7 @@ class RegistrationWorkerStore(SQLBaseStore):
     def _query_for_auth(self, txn, token):
         sql = (
             "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
-            " access_tokens.device_id"
+            " access_tokens.device_id, access_tokens.valid_until_ms"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"
@@ -432,19 +434,6 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_3pid_guest_access_token(self, medium, address):
-        ret = yield self._simple_select_one(
-            "threepid_guest_access_tokens",
-            {"medium": medium, "address": address},
-            ["guest_access_token"],
-            True,
-            "get_3pid_guest_access_token",
-        )
-        if ret:
-            defer.returnValue(ret["guest_access_token"])
-        defer.returnValue(None)
-
-    @defer.inlineCallbacks
     def get_user_id_by_threepid(self, medium, address, require_verified=False):
         """Returns user id from threepid
 
@@ -615,23 +604,29 @@ class RegistrationStore(
         )
 
         self.register_background_update_handler(
-            "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag
+            "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
         # Create a background job for culling expired 3PID validity tokens
-        hs.get_clock().looping_call(
-            self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
-        )
+        def start_cull():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "cull_expired_threepid_validation_tokens",
+                self.cull_expired_threepid_validation_tokens,
+            )
+
+        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 
     @defer.inlineCallbacks
-    def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+    def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
         """
 
         last_user = progress.get("user_id", "")
 
-        def _backgroud_update_set_deactivated_flag_txn(txn):
+        def _background_update_set_deactivated_flag_txn(txn):
             txn.execute(
                 """
                 SELECT
@@ -676,7 +671,7 @@ class RegistrationStore(
                 return False
 
         end = yield self.runInteraction(
-            "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn
+            "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
@@ -685,14 +680,16 @@ class RegistrationStore(
         defer.returnValue(batch_size)
 
     @defer.inlineCallbacks
-    def add_access_token_to_user(self, user_id, token, device_id=None):
+    def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
         """Adds an access token for the given user.
 
         Args:
             user_id (str): The user ID.
             token (str): The new access token to add.
             device_id (str): ID of the device to associate with the access
-               token
+                token
+            valid_until_ms (int|None): when the token is valid until. None for
+                no expiry.
         Raises:
             StoreError if there was a problem adding this.
         """
@@ -700,14 +697,19 @@ class RegistrationStore(
 
         yield self._simple_insert(
             "access_tokens",
-            {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
+            {
+                "id": next_id,
+                "user_id": user_id,
+                "token": token,
+                "device_id": device_id,
+                "valid_until_ms": valid_until_ms,
+            },
             desc="add_access_token_to_user",
         )
 
-    def register(
+    def register_user(
         self,
         user_id,
-        token=None,
         password_hash=None,
         was_guest=False,
         make_guest=False,
@@ -720,9 +722,6 @@ class RegistrationStore(
 
         Args:
             user_id (str): The desired user ID to register.
-            token (str): The desired access token to use for this user. If this
-                is not None, the given access token is associated with the user
-                id.
             password_hash (str): Optional. The password hash for this user.
             was_guest (bool): Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
@@ -739,10 +738,9 @@ class RegistrationStore(
             StoreError if the user_id could not be registered.
         """
         return self.runInteraction(
-            "register",
-            self._register,
+            "register_user",
+            self._register_user,
             user_id,
-            token,
             password_hash,
             was_guest,
             make_guest,
@@ -752,11 +750,10 @@ class RegistrationStore(
             user_type,
         )
 
-    def _register(
+    def _register_user(
         self,
         txn,
         user_id,
-        token,
         password_hash,
         was_guest,
         make_guest,
@@ -769,8 +766,6 @@ class RegistrationStore(
 
         now = int(self.clock.time())
 
-        next_id = self._access_tokens_id_gen.get_next()
-
         try:
             if was_guest:
                 # Ensure that the guest user actually exists
@@ -818,14 +813,6 @@ class RegistrationStore(
         if self._account_validity.enabled:
             self.set_expiration_date_for_user_txn(txn, user_id)
 
-        if token:
-            # it's possible for this to get a conflict, but only for a single user
-            # since tokens are namespaced based on their user ID
-            txn.execute(
-                "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
-                (next_id, user_id, token),
-            )
-
         if create_profile_with_displayname:
             # set a default displayname serverside to avoid ugly race
             # between auto-joins and clients trying to set displaynames
@@ -972,40 +959,6 @@ class RegistrationStore(
 
         defer.returnValue(res if res else False)
 
-    @defer.inlineCallbacks
-    def save_or_get_3pid_guest_access_token(
-        self, medium, address, access_token, inviter_user_id
-    ):
-        """
-        Gets the 3pid's guest access token if exists, else saves access_token.
-
-        Args:
-            medium (str): Medium of the 3pid. Must be "email".
-            address (str): 3pid address.
-            access_token (str): The access token to persist if none is
-                already persisted.
-            inviter_user_id (str): User ID of the inviter.
-
-        Returns:
-            deferred str: Whichever access token is persisted at the end
-            of this function call.
-        """
-
-        def insert(txn):
-            txn.execute(
-                "INSERT INTO threepid_guest_access_tokens "
-                "(medium, address, guest_access_token, first_inviter) "
-                "VALUES (?, ?, ?, ?)",
-                (medium, address, access_token, inviter_user_id),
-            )
-
-        try:
-            yield self.runInteraction("save_3pid_guest_access_token", insert)
-            defer.returnValue(access_token)
-        except self.database_engine.module.IntegrityError:
-            ret = yield self.get_3pid_guest_access_token(medium, address)
-            defer.returnValue(ret)
-
     def add_user_pending_deactivation(self, user_id):
         """
         Adds a user to the table of users who need to be parted from all the rooms they're