summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py229
1 files changed, 99 insertions, 130 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 9b6c28892c..0a66a2c17d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -37,13 +37,15 @@ class RegistrationWorkerStore(SQLBaseStore):
     def get_user_by_id(self, user_id):
         return self._simple_select_one(
             table="users",
-            keyvalues={
-                "name": user_id,
-            },
+            keyvalues={"name": user_id},
             retcols=[
-                "name", "password_hash", "is_guest",
-                "consent_version", "consent_server_notice_sent",
-                "appservice_id", "creation_ts",
+                "name",
+                "password_hash",
+                "is_guest",
+                "consent_version",
+                "consent_server_notice_sent",
+                "appservice_id",
+                "creation_ts",
             ],
             allow_none=True,
             desc="get_user_by_id",
@@ -81,9 +83,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 including the keys `name`, `is_guest`, `device_id`, `token_id`.
         """
         return self.runInteraction(
-            "get_user_by_access_token",
-            self._query_for_auth,
-            token
+            "get_user_by_access_token", self._query_for_auth, token
         )
 
     @defer.inlineCallbacks
@@ -143,10 +143,10 @@ class RegistrationWorkerStore(SQLBaseStore):
         """Gets users that match user_id case insensitively.
         Returns a mapping of user_id -> password_hash.
         """
+
         def f(txn):
             sql = (
-                "SELECT name, password_hash FROM users"
-                " WHERE lower(name) = lower(?)"
+                "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
             return dict(txn)
@@ -156,6 +156,7 @@ class RegistrationWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
+
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users")
             rows = self.cursor_to_dict(txn)
@@ -173,6 +174,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                3) bridged users
         who registered on the homeserver in the past 24 hours
         """
+
         def _count_daily_user_type(txn):
             yesterday = int(self._clock.time()) - (60 * 60 * 24)
 
@@ -193,15 +195,18 @@ class RegistrationWorkerStore(SQLBaseStore):
             for row in txn:
                 results[row[0]] = row[1]
             return results
+
         return self.runInteraction("count_daily_user_type", _count_daily_user_type)
 
     @defer.inlineCallbacks
     def count_nonbridged_users(self):
         def _count_users(txn):
-            txn.execute("""
+            txn.execute(
+                """
                 SELECT COALESCE(COUNT(*), 0) FROM users
                 WHERE appservice_id IS NULL
-            """)
+            """
+            )
             count, = txn.fetchone()
             return count
 
@@ -220,6 +225,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         avoid the case of ID 10000000 being pre-allocated, so us wasting the
         first (and shortest) many generated user IDs.
         """
+
         def _find_next_generated_user_id(txn):
             txn.execute("SELECT name FROM users")
 
@@ -227,7 +233,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
             found = set()
 
-            for user_id, in txn:
+            for (user_id,) in txn:
                 match = regex.search(user_id)
                 if match:
                     found.add(int(match.group(1)))
@@ -235,20 +241,22 @@ class RegistrationWorkerStore(SQLBaseStore):
                 if i not in found:
                     return i
 
-        defer.returnValue((yield self.runInteraction(
-            "find_next_generated_user_id",
-            _find_next_generated_user_id
-        )))
+        defer.returnValue(
+            (
+                yield self.runInteraction(
+                    "find_next_generated_user_id", _find_next_generated_user_id
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     def get_3pid_guest_access_token(self, medium, address):
         ret = yield self._simple_select_one(
             "threepid_guest_access_tokens",
-            {
-                "medium": medium,
-                "address": address
-            },
-            ["guest_access_token"], True, 'get_3pid_guest_access_token'
+            {"medium": medium, "address": address},
+            ["guest_access_token"],
+            True,
+            'get_3pid_guest_access_token',
         )
         if ret:
             defer.returnValue(ret["guest_access_token"])
@@ -266,8 +274,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             Deferred[str|None]: user id or None if no user id/threepid mapping exists
         """
         user_id = yield self.runInteraction(
-            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
-            medium, address
+            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
         )
         defer.returnValue(user_id)
 
@@ -285,11 +292,9 @@ class RegistrationWorkerStore(SQLBaseStore):
         ret = self._simple_select_one_txn(
             txn,
             "user_threepids",
-            {
-                "medium": medium,
-                "address": address
-            },
-            ['user_id'], True
+            {"medium": medium, "address": address},
+            ['user_id'],
+            True,
         )
         if ret:
             return ret['user_id']
@@ -297,41 +302,33 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self._simple_upsert("user_threepids", {
-            "medium": medium,
-            "address": address,
-        }, {
-            "user_id": user_id,
-            "validated_at": validated_at,
-            "added_at": added_at,
-        })
+        yield self._simple_upsert(
+            "user_threepids",
+            {"medium": medium, "address": address},
+            {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
+        )
 
     @defer.inlineCallbacks
     def user_get_threepids(self, user_id):
         ret = yield self._simple_select_list(
-            "user_threepids", {
-                "user_id": user_id
-            },
+            "user_threepids",
+            {"user_id": user_id},
             ['medium', 'address', 'validated_at', 'added_at'],
-            'user_get_threepids'
+            'user_get_threepids',
         )
         defer.returnValue(ret)
 
     def user_delete_threepid(self, user_id, medium, address):
         return self._simple_delete(
             "user_threepids",
-            keyvalues={
-                "user_id": user_id,
-                "medium": medium,
-                "address": address,
-            },
+            keyvalues={"user_id": user_id, "medium": medium, "address": address},
             desc="user_delete_threepids",
         )
 
 
-class RegistrationStore(RegistrationWorkerStore,
-                        background_updates.BackgroundUpdateStore):
-
+class RegistrationStore(
+    RegistrationWorkerStore, background_updates.BackgroundUpdateStore
+):
     def __init__(self, db_conn, hs):
         super(RegistrationStore, self).__init__(db_conn, hs)
 
@@ -372,18 +369,22 @@ class RegistrationStore(RegistrationWorkerStore,
 
         yield self._simple_insert(
             "access_tokens",
-            {
-                "id": next_id,
-                "user_id": user_id,
-                "token": token,
-                "device_id": device_id,
-            },
+            {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
             desc="add_access_token_to_user",
         )
 
-    def register(self, user_id, token=None, password_hash=None,
-                 was_guest=False, make_guest=False, appservice_id=None,
-                 create_profile_with_displayname=None, admin=False, user_type=None):
+    def register(
+        self,
+        user_id,
+        token=None,
+        password_hash=None,
+        was_guest=False,
+        make_guest=False,
+        appservice_id=None,
+        create_profile_with_displayname=None,
+        admin=False,
+        user_type=None,
+    ):
         """Attempts to register an account.
 
         Args:
@@ -417,7 +418,7 @@ class RegistrationStore(RegistrationWorkerStore,
             appservice_id,
             create_profile_with_displayname,
             admin,
-            user_type
+            user_type,
         )
 
     def _register(
@@ -447,10 +448,7 @@ class RegistrationStore(RegistrationWorkerStore,
                 self._simple_select_one_txn(
                     txn,
                     "users",
-                    keyvalues={
-                        "name": user_id,
-                        "is_guest": 1,
-                    },
+                    keyvalues={"name": user_id, "is_guest": 1},
                     retcols=("name",),
                     allow_none=False,
                 )
@@ -458,10 +456,7 @@ class RegistrationStore(RegistrationWorkerStore,
                 self._simple_update_one_txn(
                     txn,
                     "users",
-                    keyvalues={
-                        "name": user_id,
-                        "is_guest": 1,
-                    },
+                    keyvalues={"name": user_id, "is_guest": 1},
                     updatevalues={
                         "password_hash": password_hash,
                         "upgrade_ts": now,
@@ -469,7 +464,7 @@ class RegistrationStore(RegistrationWorkerStore,
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
-                    }
+                    },
                 )
             else:
                 self._simple_insert_txn(
@@ -483,20 +478,17 @@ class RegistrationStore(RegistrationWorkerStore,
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
-                    }
+                    },
                 )
         except self.database_engine.module.IntegrityError:
-            raise StoreError(
-                400, "User ID already taken.", errcode=Codes.USER_IN_USE
-            )
+            raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
         if token:
             # it's possible for this to get a conflict, but only for a single user
             # since tokens are namespaced based on their user ID
             txn.execute(
-                "INSERT INTO access_tokens(id, user_id, token)"
-                " VALUES (?,?,?)",
-                (next_id, user_id, token,)
+                "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
+                (next_id, user_id, token),
             )
 
         if create_profile_with_displayname:
@@ -507,12 +499,10 @@ class RegistrationStore(RegistrationWorkerStore,
             # while everything else uses the full mxid.
             txn.execute(
                 "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
-                (user_id_obj.localpart, create_profile_with_displayname)
+                (user_id_obj.localpart, create_profile_with_displayname),
             )
 
-        self._invalidate_cache_and_stream(
-            txn, self.get_user_by_id, (user_id,)
-        )
+        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
     def user_set_password_hash(self, user_id, password_hash):
@@ -521,22 +511,14 @@ class RegistrationStore(RegistrationWorkerStore,
             removes most of the entries subsequently anyway so it would be
             pointless. Use flush_user separately.
         """
+
         def user_set_password_hash_txn(txn):
             self._simple_update_one_txn(
-                txn,
-                'users', {
-                    'name': user_id
-                },
-                {
-                    'password_hash': password_hash
-                }
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_user_by_id, (user_id,)
+                txn, 'users', {'name': user_id}, {'password_hash': password_hash}
             )
-        return self.runInteraction(
-            "user_set_password_hash", user_set_password_hash_txn
-        )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+        return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
 
     def user_set_consent_version(self, user_id, consent_version):
         """Updates the user table to record privacy policy consent
@@ -549,16 +531,16 @@ class RegistrationStore(RegistrationWorkerStore,
         Raises:
             StoreError(404) if user not found
         """
+
         def f(txn):
             self._simple_update_one_txn(
                 txn,
                 table='users',
-                keyvalues={'name': user_id, },
-                updatevalues={'consent_version': consent_version, },
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_user_by_id, (user_id,)
+                keyvalues={'name': user_id},
+                updatevalues={'consent_version': consent_version},
             )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
         return self.runInteraction("user_set_consent_version", f)
 
     def user_set_consent_server_notice_sent(self, user_id, consent_version):
@@ -573,20 +555,19 @@ class RegistrationStore(RegistrationWorkerStore,
         Raises:
             StoreError(404) if user not found
         """
+
         def f(txn):
             self._simple_update_one_txn(
                 txn,
                 table='users',
-                keyvalues={'name': user_id, },
-                updatevalues={'consent_server_notice_sent': consent_version, },
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_user_by_id, (user_id,)
+                keyvalues={'name': user_id},
+                updatevalues={'consent_server_notice_sent': consent_version},
             )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
         return self.runInteraction("user_set_consent_server_notice_sent", f)
 
-    def user_delete_access_tokens(self, user_id, except_token_id=None,
-                                  device_id=None):
+    def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
         """
         Invalidate access tokens belonging to a user
 
@@ -601,10 +582,9 @@ class RegistrationStore(RegistrationWorkerStore,
             defer.Deferred[list[str, int, str|None, int]]: a list of
                 (token, token id, device id) for each of the deleted tokens
         """
+
         def f(txn):
-            keyvalues = {
-                "user_id": user_id,
-            }
+            keyvalues = {"user_id": user_id}
             if device_id is not None:
                 keyvalues["device_id"] = device_id
 
@@ -616,8 +596,9 @@ class RegistrationStore(RegistrationWorkerStore,
                 values.append(except_token_id)
 
             txn.execute(
-                "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
-                values
+                "SELECT token, id, device_id FROM access_tokens WHERE %s"
+                % where_clause,
+                values,
             )
             tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
@@ -626,25 +607,16 @@ class RegistrationStore(RegistrationWorkerStore,
                     txn, self.get_user_by_access_token, (token,)
                 )
 
-            txn.execute(
-                "DELETE FROM access_tokens WHERE %s" % where_clause,
-                values
-            )
+            txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 
             return tokens_and_devices
 
-        return self.runInteraction(
-            "user_delete_access_tokens", f,
-        )
+        return self.runInteraction("user_delete_access_tokens", f)
 
     def delete_access_token(self, access_token):
         def f(txn):
             self._simple_delete_one_txn(
-                txn,
-                table="access_tokens",
-                keyvalues={
-                    "token": access_token
-                },
+                txn, table="access_tokens", keyvalues={"token": access_token}
             )
 
             self._invalidate_cache_and_stream(
@@ -667,7 +639,7 @@ class RegistrationStore(RegistrationWorkerStore,
 
     @defer.inlineCallbacks
     def save_or_get_3pid_guest_access_token(
-            self, medium, address, access_token, inviter_user_id
+        self, medium, address, access_token, inviter_user_id
     ):
         """
         Gets the 3pid's guest access token if exists, else saves access_token.
@@ -683,12 +655,13 @@ class RegistrationStore(RegistrationWorkerStore,
             deferred str: Whichever access token is persisted at the end
             of this function call.
         """
+
         def insert(txn):
             txn.execute(
                 "INSERT INTO threepid_guest_access_tokens "
                 "(medium, address, guest_access_token, first_inviter) "
                 "VALUES (?, ?, ?, ?)",
-                (medium, address, access_token, inviter_user_id)
+                (medium, address, access_token, inviter_user_id),
             )
 
         try:
@@ -705,9 +678,7 @@ class RegistrationStore(RegistrationWorkerStore,
         """
         return self._simple_insert(
             "users_pending_deactivation",
-            values={
-                "user_id": user_id,
-            },
+            values={"user_id": user_id},
             desc="add_user_pending_deactivation",
         )
 
@@ -720,9 +691,7 @@ class RegistrationStore(RegistrationWorkerStore,
         # the table, so somehow duplicate entries have ended up in it.
         return self._simple_delete(
             "users_pending_deactivation",
-            keyvalues={
-                "user_id": user_id,
-            },
+            keyvalues={"user_id": user_id},
             desc="del_user_pending_deactivation",
         )