summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-09-19 18:13:31 +0100
committerRichard van der Hoff <richard@matrix.org>2019-09-19 18:13:31 +0100
commitb65327ff661b87991566314cf07088a7e789b69b (patch)
tree75da1c79db8165ee6b3676b784a0d567551cb82d /synapse/storage/registration.py
parentbetter logging (diff)
parentfix sample config (diff)
downloadsynapse-b65327ff661b87991566314cf07088a7e789b69b.tar.xz
Merge branch 'develop' into rav/saml_mapping_work
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py231
1 files changed, 152 insertions, 79 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 1e3c2148f6..cc1b1b73c9 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -57,6 +57,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 "consent_server_notice_sent",
                 "appservice_id",
                 "creation_ts",
+                "user_type",
             ],
             allow_none=True,
             desc="get_user_by_id",
@@ -273,6 +274,14 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def is_server_admin(self, user):
+        """Determines if a user is an admin of this homeserver.
+
+        Args:
+            user (UserID): user ID of the user to test
+
+        Returns (bool):
+            true iff the user is a server admin, false otherwise.
+        """
         res = yield self._simple_select_one_onecol(
             table="users",
             keyvalues={"name": user.to_string()},
@@ -283,6 +292,21 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return res if res else False
 
+    def set_server_admin(self, user, admin):
+        """Sets whether a user is an admin of this homeserver.
+
+        Args:
+            user (UserID): user ID of the user to test
+            admin (bool): true iff the user is to be a server admin,
+                false otherwise.
+        """
+        return self._simple_update_one(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            updatevalues={"admin": 1 if admin else 0},
+            desc="set_server_admin",
+        )
+
     def _query_for_auth(self, txn, token):
         sql = (
             "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
@@ -300,6 +324,19 @@ class RegistrationWorkerStore(SQLBaseStore):
         return None
 
     @cachedInlineCallbacks()
+    def is_real_user(self, user_id):
+        """Determines if the user is a real user, ie does not have a 'user_type'.
+
+        Args:
+            user_id (str): user id to test
+
+        Returns:
+            Deferred[bool]: True if user 'user_type' is null or empty string
+        """
+        res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
+        return res
+
+    @cachedInlineCallbacks()
     def is_support_user(self, user_id):
         """Determines if the user is of type UserTypes.SUPPORT
 
@@ -314,6 +351,16 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
         return res
 
+    def is_real_user_txn(self, txn, user_id):
+        res = self._simple_select_one_onecol_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="user_type",
+            allow_none=True,
+        )
+        return res is None
+
     def is_support_user_txn(self, txn, user_id):
         res = self._simple_select_one_onecol_txn(
             txn=txn,
@@ -419,6 +466,20 @@ class RegistrationWorkerStore(SQLBaseStore):
         return ret
 
     @defer.inlineCallbacks
+    def count_real_users(self):
+        """Counts all users without a special user_type registered on the homeserver."""
+
+        def _count_users(txn):
+            txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
+            rows = self.cursor_to_dict(txn)
+            if rows:
+                return rows[0]["users"]
+            return 0
+
+        ret = yield self.runInteraction("count_real_users", _count_users)
+        return ret
+
+    @defer.inlineCallbacks
     def find_next_generated_user_id_localpart(self):
         """
         Gets the localpart of the next generated user ID.
@@ -611,6 +672,85 @@ class RegistrationWorkerStore(SQLBaseStore):
         # Convert the integer into a boolean.
         return res == 1
 
+    def get_threepid_validation_session(
+        self, medium, client_secret, address=None, sid=None, validated=True
+    ):
+        """Gets a session_id and last_send_attempt (if available) for a
+        client_secret/medium/(address|session_id) combo
+
+        Args:
+            medium (str|None): The medium of the 3PID
+            address (str|None): The address of the 3PID
+            sid (str|None): The ID of the validation session
+            client_secret (str|None): A unique string provided by the client to
+                help identify this validation attempt
+            validated (bool|None): Whether sessions should be filtered by
+                whether they have been validated already or not. None to
+                perform no filtering
+
+        Returns:
+            deferred {str, int}|None: A dict containing the
+                latest session_id and send_attempt count for this 3PID.
+                Otherwise None if there hasn't been a previous attempt
+        """
+        keyvalues = {"medium": medium, "client_secret": client_secret}
+        if address:
+            keyvalues["address"] = address
+        if sid:
+            keyvalues["session_id"] = sid
+
+        assert address or sid
+
+        def get_threepid_validation_session_txn(txn):
+            sql = """
+                SELECT address, session_id, medium, client_secret,
+                last_send_attempt, validated_at
+                FROM threepid_validation_session WHERE %s
+                """ % (
+                " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+            )
+
+            if validated is not None:
+                sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
+
+            sql += " LIMIT 1"
+
+            txn.execute(sql, list(keyvalues.values()))
+            rows = self.cursor_to_dict(txn)
+            if not rows:
+                return None
+
+            return rows[0]
+
+        return self.runInteraction(
+            "get_threepid_validation_session", get_threepid_validation_session_txn
+        )
+
+    def delete_threepid_session(self, session_id):
+        """Removes a threepid validation session from the database. This can
+        be done after validation has been performed and whatever action was
+        waiting on it has been carried out
+
+        Args:
+            session_id (str): The ID of the session to delete
+        """
+
+        def delete_threepid_session_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_token",
+                keyvalues={"session_id": session_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_session",
+                keyvalues={"session_id": session_id},
+            )
+
+        return self.runInteraction(
+            "delete_threepid_session", delete_threepid_session_txn
+        )
+
 
 class RegistrationStore(
     RegistrationWorkerStore, background_updates.BackgroundUpdateStore
@@ -866,6 +1006,17 @@ class RegistrationStore(
                 (user_id_obj.localpart, create_profile_with_displayname),
             )
 
+        if self.hs.config.stats_enabled:
+            # we create a new completed user statistics row
+
+            # we don't strictly need current_token since this user really can't
+            # have any state deltas before now (as it is a new user), but still,
+            # we include it for completeness.
+            current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+            self._update_stats_delta_txn(
+                txn, now, "user", user_id, {}, complete_with_stream_id=current_token
+            )
+
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
@@ -1088,60 +1239,6 @@ class RegistrationStore(
 
         return 1
 
-    def get_threepid_validation_session(
-        self, medium, client_secret, address=None, sid=None, validated=True
-    ):
-        """Gets a session_id and last_send_attempt (if available) for a
-        client_secret/medium/(address|session_id) combo
-
-        Args:
-            medium (str|None): The medium of the 3PID
-            address (str|None): The address of the 3PID
-            sid (str|None): The ID of the validation session
-            client_secret (str|None): A unique string provided by the client to
-                help identify this validation attempt
-            validated (bool|None): Whether sessions should be filtered by
-                whether they have been validated already or not. None to
-                perform no filtering
-
-        Returns:
-            deferred {str, int}|None: A dict containing the
-                latest session_id and send_attempt count for this 3PID.
-                Otherwise None if there hasn't been a previous attempt
-        """
-        keyvalues = {"medium": medium, "client_secret": client_secret}
-        if address:
-            keyvalues["address"] = address
-        if sid:
-            keyvalues["session_id"] = sid
-
-        assert address or sid
-
-        def get_threepid_validation_session_txn(txn):
-            sql = """
-                SELECT address, session_id, medium, client_secret,
-                last_send_attempt, validated_at
-                FROM threepid_validation_session WHERE %s
-                """ % (
-                " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
-            )
-
-            if validated is not None:
-                sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
-
-            sql += " LIMIT 1"
-
-            txn.execute(sql, list(keyvalues.values()))
-            rows = self.cursor_to_dict(txn)
-            if not rows:
-                return None
-
-            return rows[0]
-
-        return self.runInteraction(
-            "get_threepid_validation_session", get_threepid_validation_session_txn
-        )
-
     def validate_threepid_session(self, session_id, client_secret, token, current_ts):
         """Attempt to validate a threepid session using a token
 
@@ -1157,6 +1254,7 @@ class RegistrationStore(
             deferred str|None: A str representing a link to redirect the user
             to if there is one.
         """
+
         # Insert everything into a transaction in order to run atomically
         def validate_threepid_session_txn(txn):
             row = self._simple_select_one_txn(
@@ -1328,31 +1426,6 @@ class RegistrationStore(
             self.clock.time_msec(),
         )
 
-    def delete_threepid_session(self, session_id):
-        """Removes a threepid validation session from the database. This can
-        be done after validation has been performed and whatever action was
-        waiting on it has been carried out
-
-        Args:
-            session_id (str): The ID of the session to delete
-        """
-
-        def delete_threepid_session_txn(txn):
-            self._simple_delete_txn(
-                txn,
-                table="threepid_validation_token",
-                keyvalues={"session_id": session_id},
-            )
-            self._simple_delete_txn(
-                txn,
-                table="threepid_validation_session",
-                keyvalues={"session_id": session_id},
-            )
-
-        return self.runInteraction(
-            "delete_threepid_session", delete_threepid_session_txn
-        )
-
     def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
         self._simple_update_one_txn(
             txn=txn,