summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py243
1 files changed, 144 insertions, 99 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..a85867936f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,14 +14,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import re
 from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Cursor
@@ -36,15 +38,33 @@ logger = logging.getLogger(__name__)
 
 class RegistrationWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
 
+        # Note: we don't check this sequence for consistency as we'd have to
+        # call `find_max_generated_user_id_localpart` each time, which is
+        # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
             database.engine, find_max_generated_user_id_localpart, "user_id_seq",
         )
 
+        self._account_validity = hs.config.account_validity
+        if hs.config.run_background_tasks and self._account_validity.enabled:
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "account_validity_set_expiration_dates",
+                self._set_expiration_date_when_missing,
+            )
+
+        # Create a background job for culling expired 3PID validity tokens
+        if hs.config.run_background_tasks:
+            self.clock.looping_call(
+                self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
@@ -116,6 +136,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_expiration_ts_for_user",
         )
 
+    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+        """
+        Returns whether an user account is expired.
+
+        Args:
+            user_id: The user's ID
+            current_ts: The current timestamp
+
+        Returns:
+            Whether the user account has expired
+        """
+        expiration_ts = await self.get_expiration_ts_for_user(user_id)
+        return expiration_ts is not None and current_ts >= expiration_ts
+
     async def set_account_validity_for_user(
         self,
         user_id: str,
@@ -379,7 +413,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
-    ) -> str:
+    ) -> Optional[str]:
         """Look up a user by their external auth id
 
         Args:
@@ -387,7 +421,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             external_id: id on that system
 
         Returns:
-            str|None: the mxid of the user, or None if they are not known
+            the mxid of the user, or None if they are not known
         """
         return await self.db_pool.simple_select_one_onecol(
             table="user_external_ids",
@@ -761,10 +795,82 @@ class RegistrationWorkerStore(SQLBaseStore):
             "delete_threepid_session", delete_threepid_session_txn
         )
 
+    @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+    async def cull_expired_threepid_validation_tokens(self) -> None:
+        """Remove threepid validation tokens with expiry dates that have passed"""
+
+        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+            sql = """
+            DELETE FROM threepid_validation_token WHERE
+            expires < ?
+            """
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "cull_expired_threepid_validation_tokens",
+            cull_expired_threepid_validation_tokens_txn,
+            self.clock.time_msec(),
+        )
+
+    async def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database, filtering out deactivated users.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+            )
+            txn.execute(sql, [])
+
+            res = self.db_pool.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(
+                        txn, user["name"], use_delta=True
+                    )
+
+        await self.db_pool.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+        """Sets an expiration date to the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d ; now + period] range, d being a
+                delta equal to 10% of the validity period.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+
+        if use_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts,
+            )
+
+        self.db_pool.simple_upsert_txn(
+            txn,
+            "account_validity",
+            keyvalues={"user_id": user_id},
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
@@ -892,30 +998,10 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
-        self._account_validity = hs.config.account_validity
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
-        if self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
-            )
-
-        # Create a background job for culling expired 3PID validity tokens
-        def start_cull():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "cull_expired_threepid_validation_tokens",
-                self.cull_expired_threepid_validation_tokens,
-            )
-
-        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -947,6 +1033,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
+    def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn, "access_tokens", {"token": token}, "device_id"
+        )
+
+        self.db_pool.simple_update_txn(
+            txn, "access_tokens", {"token": token}, {"device_id": device_id}
+        )
+
+        self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+        return old_device_id
+
+    async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+        """Sets the device ID associated with an access token.
+
+        Args:
+            token: The access token to modify.
+            device_id: The new device ID.
+        Returns:
+            The old device ID associated with the access token.
+        """
+
+        return await self.db_pool.runInteraction(
+            "set_device_for_access_token",
+            self._set_device_for_access_token_txn,
+            token,
+            device_id,
+        )
+
     async def register_user(
         self,
         user_id: str,
@@ -1430,22 +1546,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def cull_expired_threepid_validation_tokens(self) -> None:
-        """Remove threepid validation tokens with expiry dates that have passed"""
-
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
-            sql = """
-            DELETE FROM threepid_validation_token WHERE
-            expires < ?
-            """
-            txn.execute(sql, (ts,))
-
-        await self.db_pool.runInteraction(
-            "cull_expired_threepid_validation_tokens",
-            cull_expired_threepid_validation_tokens_txn,
-            self.clock.time_msec(),
-        )
-
     async def set_user_deactivated_status(
         self, user_id: str, deactivated: bool
     ) -> None:
@@ -1475,61 +1575,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    async def _set_expiration_date_when_missing(self):
-        """
-        Retrieves the list of registered users that don't have an expiration date, and
-        adds an expiration date for each of them.
-        """
-
-        def select_users_with_no_expiration_date_txn(txn):
-            """Retrieves the list of registered users with no expiration date from the
-            database, filtering out deactivated users.
-            """
-            sql = (
-                "SELECT users.name FROM users"
-                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
-                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
-            )
-            txn.execute(sql, [])
-
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
-
-        await self.db_pool.runInteraction(
-            "get_users_with_no_expiration_date",
-            select_users_with_no_expiration_date_txn,
-        )
-
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
-        """Sets an expiration date to the account with the given user ID.
-
-        Args:
-             user_id (str): User ID to set an expiration date for.
-             use_delta (bool): If set to False, the expiration date for the user will be
-                now + validity period. If set to True, this expiration date will be a
-                random value in the [now + period - d ; now + period] range, d being a
-                delta equal to 10% of the validity period.
-        """
-        now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
-
-        if use_delta:
-            expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
-                expiration_ts,
-            )
-
-        self.db_pool.simple_upsert_txn(
-            txn,
-            "account_validity",
-            keyvalues={"user_id": user_id},
-            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
-        )
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """