summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py323
1 files changed, 237 insertions, 86 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 55e4e84d71..241a7be51e 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,9 +22,10 @@ from six import iterkeys
 from six.moves import range
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, ThreepidValidationError
+from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
@@ -56,6 +57,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 "consent_server_notice_sent",
                 "appservice_id",
                 "creation_ts",
+                "user_type",
             ],
             allow_none=True,
             desc="get_user_by_id",
@@ -272,6 +274,14 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def is_server_admin(self, user):
+        """Determines if a user is an admin of this homeserver.
+
+        Args:
+            user (UserID): user ID of the user to test
+
+        Returns (bool):
+            true iff the user is a server admin, false otherwise.
+        """
         res = yield self._simple_select_one_onecol(
             table="users",
             keyvalues={"name": user.to_string()},
@@ -282,6 +292,21 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return res if res else False
 
+    def set_server_admin(self, user, admin):
+        """Sets whether a user is an admin of this homeserver.
+
+        Args:
+            user (UserID): user ID of the user to test
+            admin (bool): true iff the user is to be a server admin,
+                false otherwise.
+        """
+        return self._simple_update_one(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            updatevalues={"admin": 1 if admin else 0},
+            desc="set_server_admin",
+        )
+
     def _query_for_auth(self, txn, token):
         sql = (
             "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
@@ -299,6 +324,19 @@ class RegistrationWorkerStore(SQLBaseStore):
         return None
 
     @cachedInlineCallbacks()
+    def is_real_user(self, user_id):
+        """Determines if the user is a real user, ie does not have a 'user_type'.
+
+        Args:
+            user_id (str): user id to test
+
+        Returns:
+            Deferred[bool]: True if user 'user_type' is null or empty string
+        """
+        res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
+        return res
+
+    @cachedInlineCallbacks()
     def is_support_user(self, user_id):
         """Determines if the user is of type UserTypes.SUPPORT
 
@@ -313,6 +351,16 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
         return res
 
+    def is_real_user_txn(self, txn, user_id):
+        res = self._simple_select_one_onecol_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="user_type",
+            allow_none=True,
+        )
+        return res is None
+
     def is_support_user_txn(self, txn, user_id):
         res = self._simple_select_one_onecol_txn(
             txn=txn,
@@ -337,6 +385,26 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def get_user_by_external_id(
+        self, auth_provider: str, external_id: str
+    ) -> str:
+        """Look up a user by their external auth id
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+
+        Returns:
+            str|None: the mxid of the user, or None if they are not known
+        """
+        return await self._simple_select_one_onecol(
+            table="user_external_ids",
+            keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_user_by_external_id",
+        )
+
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
@@ -398,6 +466,20 @@ class RegistrationWorkerStore(SQLBaseStore):
         return ret
 
     @defer.inlineCallbacks
+    def count_real_users(self):
+        """Counts all users without a special user_type registered on the homeserver."""
+
+        def _count_users(txn):
+            txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
+            rows = self.cursor_to_dict(txn)
+            if rows:
+                return rows[0]["users"]
+            return 0
+
+        ret = yield self.runInteraction("count_real_users", _count_users)
+        return ret
+
+    @defer.inlineCallbacks
     def find_next_generated_user_id_localpart(self):
         """
         Gets the localpart of the next generated user ID.
@@ -434,7 +516,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_user_id_by_threepid(self, medium, address, require_verified=False):
+    def get_user_id_by_threepid(self, medium, address):
         """Returns user id from threepid
 
         Args:
@@ -525,6 +607,26 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="add_user_bound_threepid",
         )
 
+    def user_get_bound_threepids(self, user_id):
+        """Get the threepids that a user has bound to an identity server through the homeserver
+        The homeserver remembers where binds to an identity server occurred. Using this
+        method can retrieve those threepids.
+
+        Args:
+            user_id (str): The ID of the user to retrieve threepids for
+
+        Returns:
+            Deferred[list[dict]]: List of dictionaries containing the following:
+                medium (str): The medium of the threepid (e.g "email")
+                address (str): The address of the threepid (e.g "bob@example.com")
+        """
+        return self._simple_select_list(
+            table="user_threepid_id_server",
+            keyvalues={"user_id": user_id},
+            retcols=["medium", "address"],
+            desc="user_get_bound_threepids",
+        )
+
     def remove_user_bound_threepid(self, user_id, medium, address, id_server):
         """The server proxied an unbind request to the given identity server on
         behalf of the given user, so we remove the mapping of threepid to
@@ -590,6 +692,98 @@ class RegistrationWorkerStore(SQLBaseStore):
         # Convert the integer into a boolean.
         return res == 1
 
+    def get_threepid_validation_session(
+        self, medium, client_secret, address=None, sid=None, validated=True
+    ):
+        """Gets a session_id and last_send_attempt (if available) for a
+        combination of validation metadata
+
+        Args:
+            medium (str|None): The medium of the 3PID
+            address (str|None): The address of the 3PID
+            sid (str|None): The ID of the validation session
+            client_secret (str): A unique string provided by the client to help identify this
+                validation attempt
+            validated (bool|None): Whether sessions should be filtered by
+                whether they have been validated already or not. None to
+                perform no filtering
+
+        Returns:
+            Deferred[dict|None]: A dict containing the following:
+                * address - address of the 3pid
+                * medium - medium of the 3pid
+                * client_secret - a secret provided by the client for this validation session
+                * session_id - ID of the validation session
+                * send_attempt - a number serving to dedupe send attempts for this session
+                * validated_at - timestamp of when this session was validated if so
+
+                Otherwise None if a validation session is not found
+        """
+        if not client_secret:
+            raise SynapseError(
+                400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM
+            )
+
+        keyvalues = {"client_secret": client_secret}
+        if medium:
+            keyvalues["medium"] = medium
+        if address:
+            keyvalues["address"] = address
+        if sid:
+            keyvalues["session_id"] = sid
+
+        assert address or sid
+
+        def get_threepid_validation_session_txn(txn):
+            sql = """
+                SELECT address, session_id, medium, client_secret,
+                last_send_attempt, validated_at
+                FROM threepid_validation_session WHERE %s
+                """ % (
+                " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+            )
+
+            if validated is not None:
+                sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
+
+            sql += " LIMIT 1"
+
+            txn.execute(sql, list(keyvalues.values()))
+            rows = self.cursor_to_dict(txn)
+            if not rows:
+                return None
+
+            return rows[0]
+
+        return self.runInteraction(
+            "get_threepid_validation_session", get_threepid_validation_session_txn
+        )
+
+    def delete_threepid_session(self, session_id):
+        """Removes a threepid validation session from the database. This can
+        be done after validation has been performed and whatever action was
+        waiting on it has been carried out
+
+        Args:
+            session_id (str): The ID of the session to delete
+        """
+
+        def delete_threepid_session_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_token",
+                keyvalues={"session_id": session_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_session",
+                keyvalues={"session_id": session_id},
+            )
+
+        return self.runInteraction(
+            "delete_threepid_session", delete_threepid_session_txn
+        )
+
 
 class RegistrationStore(
     RegistrationWorkerStore, background_updates.BackgroundUpdateStore
@@ -671,7 +865,7 @@ class RegistrationStore(
             rows = self.cursor_to_dict(txn)
 
             if not rows:
-                return True
+                return True, 0
 
             rows_processed_nb = 0
 
@@ -687,18 +881,18 @@ class RegistrationStore(
             )
 
             if batch_size > len(rows):
-                return True
+                return True, len(rows)
             else:
-                return False
+                return False, len(rows)
 
-        end = yield self.runInteraction(
+        end, nb_processed = yield self.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
             yield self._end_background_update("users_set_deactivated_flag")
 
-        return batch_size
+        return nb_processed
 
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -845,9 +1039,40 @@ class RegistrationStore(
                 (user_id_obj.localpart, create_profile_with_displayname),
             )
 
+        if self.hs.config.stats_enabled:
+            # we create a new completed user statistics row
+
+            # we don't strictly need current_token since this user really can't
+            # have any state deltas before now (as it is a new user), but still,
+            # we include it for completeness.
+            current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+            self._update_stats_delta_txn(
+                txn, now, "user", user_id, {}, complete_with_stream_id=current_token
+            )
+
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
+    def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> Deferred:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        return self._simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     def user_set_password_hash(self, user_id, password_hash):
         """
         NB. This does *not* evict any cache because the one use for this
@@ -1047,60 +1272,6 @@ class RegistrationStore(
 
         return 1
 
-    def get_threepid_validation_session(
-        self, medium, client_secret, address=None, sid=None, validated=True
-    ):
-        """Gets a session_id and last_send_attempt (if available) for a
-        client_secret/medium/(address|session_id) combo
-
-        Args:
-            medium (str|None): The medium of the 3PID
-            address (str|None): The address of the 3PID
-            sid (str|None): The ID of the validation session
-            client_secret (str|None): A unique string provided by the client to
-                help identify this validation attempt
-            validated (bool|None): Whether sessions should be filtered by
-                whether they have been validated already or not. None to
-                perform no filtering
-
-        Returns:
-            deferred {str, int}|None: A dict containing the
-                latest session_id and send_attempt count for this 3PID.
-                Otherwise None if there hasn't been a previous attempt
-        """
-        keyvalues = {"medium": medium, "client_secret": client_secret}
-        if address:
-            keyvalues["address"] = address
-        if sid:
-            keyvalues["session_id"] = sid
-
-        assert address or sid
-
-        def get_threepid_validation_session_txn(txn):
-            sql = """
-                SELECT address, session_id, medium, client_secret,
-                last_send_attempt, validated_at
-                FROM threepid_validation_session WHERE %s
-                """ % (
-                " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
-            )
-
-            if validated is not None:
-                sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
-
-            sql += " LIMIT 1"
-
-            txn.execute(sql, list(keyvalues.values()))
-            rows = self.cursor_to_dict(txn)
-            if not rows:
-                return None
-
-            return rows[0]
-
-        return self.runInteraction(
-            "get_threepid_validation_session", get_threepid_validation_session_txn
-        )
-
     def validate_threepid_session(self, session_id, client_secret, token, current_ts):
         """Attempt to validate a threepid session using a token
 
@@ -1112,10 +1283,15 @@ class RegistrationStore(
             current_ts (int): The current unix time in milliseconds. Used for
                 checking token expiry status
 
+        Raises:
+            ThreepidValidationError: if a matching validation token was not found or has
+                expired
+
         Returns:
             deferred str|None: A str representing a link to redirect the user
             to if there is one.
         """
+
         # Insert everything into a transaction in order to run atomically
         def validate_threepid_session_txn(txn):
             row = self._simple_select_one_txn(
@@ -1287,31 +1463,6 @@ class RegistrationStore(
             self.clock.time_msec(),
         )
 
-    def delete_threepid_session(self, session_id):
-        """Removes a threepid validation session from the database. This can
-        be done after validation has been performed and whatever action was
-        waiting on it has been carried out
-
-        Args:
-            session_id (str): The ID of the session to delete
-        """
-
-        def delete_threepid_session_txn(txn):
-            self._simple_delete_txn(
-                txn,
-                table="threepid_validation_token",
-                keyvalues={"session_id": session_id},
-            )
-            self._simple_delete_txn(
-                txn,
-                table="threepid_validation_session",
-                keyvalues={"session_id": session_id},
-            )
-
-        return self.runInteraction(
-            "delete_threepid_session", delete_threepid_session_txn
-        )
-
     def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
         self._simple_update_one_txn(
             txn=txn,