diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 231 |
1 files changed, 152 insertions, 79 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 1e3c2148f6..cc1b1b73c9 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -57,6 +57,7 @@ class RegistrationWorkerStore(SQLBaseStore): "consent_server_notice_sent", "appservice_id", "creation_ts", + "user_type", ], allow_none=True, desc="get_user_by_id", @@ -273,6 +274,14 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def is_server_admin(self, user): + """Determines if a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + + Returns (bool): + true iff the user is a server admin, false otherwise. + """ res = yield self._simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, @@ -283,6 +292,21 @@ class RegistrationWorkerStore(SQLBaseStore): return res if res else False + def set_server_admin(self, user, admin): + """Sets whether a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + admin (bool): true iff the user is to be a server admin, + false otherwise. + """ + return self._simple_update_one( + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"admin": 1 if admin else 0}, + desc="set_server_admin", + ) + def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.is_guest, access_tokens.id as token_id," @@ -300,6 +324,19 @@ class RegistrationWorkerStore(SQLBaseStore): return None @cachedInlineCallbacks() + def is_real_user(self, user_id): + """Determines if the user is a real user, ie does not have a 'user_type'. + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user 'user_type' is null or empty string + """ + res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + return res + + @cachedInlineCallbacks() def is_support_user(self, user_id): """Determines if the user is of type UserTypes.SUPPORT @@ -314,6 +351,16 @@ class RegistrationWorkerStore(SQLBaseStore): ) return res + def is_real_user_txn(self, txn, user_id): + res = self._simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return res is None + def is_support_user_txn(self, txn, user_id): res = self._simple_select_one_onecol_txn( txn=txn, @@ -419,6 +466,20 @@ class RegistrationWorkerStore(SQLBaseStore): return ret @defer.inlineCallbacks + def count_real_users(self): + """Counts all users without a special user_type registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_real_users", _count_users) + return ret + + @defer.inlineCallbacks def find_next_generated_user_id_localpart(self): """ Gets the localpart of the next generated user ID. @@ -611,6 +672,85 @@ class RegistrationWorkerStore(SQLBaseStore): # Convert the integer into a boolean. return res == 1 + def get_threepid_validation_session( + self, medium, client_secret, address=None, sid=None, validated=True + ): + """Gets a session_id and last_send_attempt (if available) for a + client_secret/medium/(address|session_id) combo + + Args: + medium (str|None): The medium of the 3PID + address (str|None): The address of the 3PID + sid (str|None): The ID of the validation session + client_secret (str|None): A unique string provided by the client to + help identify this validation attempt + validated (bool|None): Whether sessions should be filtered by + whether they have been validated already or not. None to + perform no filtering + + Returns: + deferred {str, int}|None: A dict containing the + latest session_id and send_attempt count for this 3PID. + Otherwise None if there hasn't been a previous attempt + """ + keyvalues = {"medium": medium, "client_secret": client_secret} + if address: + keyvalues["address"] = address + if sid: + keyvalues["session_id"] = sid + + assert address or sid + + def get_threepid_validation_session_txn(txn): + sql = """ + SELECT address, session_id, medium, client_secret, + last_send_attempt, validated_at + FROM threepid_validation_session WHERE %s + """ % ( + " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + ) + + if validated is not None: + sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") + + sql += " LIMIT 1" + + txn.execute(sql, list(keyvalues.values())) + rows = self.cursor_to_dict(txn) + if not rows: + return None + + return rows[0] + + return self.runInteraction( + "get_threepid_validation_session", get_threepid_validation_session_txn + ) + + def delete_threepid_session(self, session_id): + """Removes a threepid validation session from the database. This can + be done after validation has been performed and whatever action was + waiting on it has been carried out + + Args: + session_id (str): The ID of the session to delete + """ + + def delete_threepid_session_txn(txn): + self._simple_delete_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id}, + ) + self._simple_delete_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + ) + + return self.runInteraction( + "delete_threepid_session", delete_threepid_session_txn + ) + class RegistrationStore( RegistrationWorkerStore, background_updates.BackgroundUpdateStore @@ -866,6 +1006,17 @@ class RegistrationStore( (user_id_obj.localpart, create_profile_with_displayname), ) + if self.hs.config.stats_enabled: + # we create a new completed user statistics row + + # we don't strictly need current_token since this user really can't + # have any state deltas before now (as it is a new user), but still, + # we include it for completeness. + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + self._update_stats_delta_txn( + txn, now, "user", user_id, {}, complete_with_stream_id=current_token + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) @@ -1088,60 +1239,6 @@ class RegistrationStore( return 1 - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): - """Gets a session_id and last_send_attempt (if available) for a - client_secret/medium/(address|session_id) combo - - Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str|None): A unique string provided by the client to - help identify this validation attempt - validated (bool|None): Whether sessions should be filtered by - whether they have been validated already or not. None to - perform no filtering - - Returns: - deferred {str, int}|None: A dict containing the - latest session_id and send_attempt count for this 3PID. - Otherwise None if there hasn't been a previous attempt - """ - keyvalues = {"medium": medium, "client_secret": client_secret} - if address: - keyvalues["address"] = address - if sid: - keyvalues["session_id"] = sid - - assert address or sid - - def get_threepid_validation_session_txn(txn): - sql = """ - SELECT address, session_id, medium, client_secret, - last_send_attempt, validated_at - FROM threepid_validation_session WHERE %s - """ % ( - " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), - ) - - if validated is not None: - sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") - - sql += " LIMIT 1" - - txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) - if not rows: - return None - - return rows[0] - - return self.runInteraction( - "get_threepid_validation_session", get_threepid_validation_session_txn - ) - def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token @@ -1157,6 +1254,7 @@ class RegistrationStore( deferred str|None: A str representing a link to redirect the user to if there is one. """ + # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): row = self._simple_select_one_txn( @@ -1328,31 +1426,6 @@ class RegistrationStore( self.clock.time_msec(), ) - def delete_threepid_session(self, session_id): - """Removes a threepid validation session from the database. This can - be done after validation has been performed and whatever action was - waiting on it has been carried out - - Args: - session_id (str): The ID of the session to delete - """ - - def delete_threepid_session_txn(txn): - self._simple_delete_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id}, - ) - self._simple_delete_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - ) - - return self.runInteraction( - "delete_threepid_session", delete_threepid_session_txn - ) - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): self._simple_update_one_txn( txn=txn, |