summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py331
1 files changed, 206 insertions, 125 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index a78850259f..a1085ad80c 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -38,13 +38,15 @@ class RegistrationWorkerStore(SQLBaseStore):
     def get_user_by_id(self, user_id):
         return self._simple_select_one(
             table="users",
-            keyvalues={
-                "name": user_id,
-            },
+            keyvalues={"name": user_id},
             retcols=[
-                "name", "password_hash", "is_guest",
-                "consent_version", "consent_server_notice_sent",
-                "appservice_id", "creation_ts",
+                "name",
+                "password_hash",
+                "is_guest",
+                "consent_version",
+                "consent_server_notice_sent",
+                "appservice_id",
+                "creation_ts",
             ],
             allow_none=True,
             desc="get_user_by_id",
@@ -82,9 +84,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 including the keys `name`, `is_guest`, `device_id`, `token_id`.
         """
         return self.runInteraction(
-            "get_user_by_access_token",
-            self._query_for_auth,
-            token
+            "get_user_by_access_token", self._query_for_auth, token
         )
 
     @cachedInlineCallbacks()
@@ -295,10 +295,10 @@ class RegistrationWorkerStore(SQLBaseStore):
         """Gets users that match user_id case insensitively.
         Returns a mapping of user_id -> password_hash.
         """
+
         def f(txn):
             sql = (
-                "SELECT name, password_hash FROM users"
-                " WHERE lower(name) = lower(?)"
+                "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
             return dict(txn)
@@ -308,6 +308,7 @@ class RegistrationWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
+
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users")
             rows = self.cursor_to_dict(txn)
@@ -325,6 +326,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                3) bridged users
         who registered on the homeserver in the past 24 hours
         """
+
         def _count_daily_user_type(txn):
             yesterday = int(self._clock.time()) - (60 * 60 * 24)
 
@@ -345,15 +347,18 @@ class RegistrationWorkerStore(SQLBaseStore):
             for row in txn:
                 results[row[0]] = row[1]
             return results
+
         return self.runInteraction("count_daily_user_type", _count_daily_user_type)
 
     @defer.inlineCallbacks
     def count_nonbridged_users(self):
         def _count_users(txn):
-            txn.execute("""
+            txn.execute(
+                """
                 SELECT COALESCE(COUNT(*), 0) FROM users
                 WHERE appservice_id IS NULL
-            """)
+            """
+            )
             count, = txn.fetchone()
             return count
 
@@ -372,6 +377,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         avoid the case of ID 10000000 being pre-allocated, so us wasting the
         first (and shortest) many generated user IDs.
         """
+
         def _find_next_generated_user_id(txn):
             txn.execute("SELECT name FROM users")
 
@@ -379,7 +385,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
             found = set()
 
-            for user_id, in txn:
+            for (user_id,) in txn:
                 match = regex.search(user_id)
                 if match:
                     found.add(int(match.group(1)))
@@ -387,20 +393,22 @@ class RegistrationWorkerStore(SQLBaseStore):
                 if i not in found:
                     return i
 
-        defer.returnValue((yield self.runInteraction(
-            "find_next_generated_user_id",
-            _find_next_generated_user_id
-        )))
+        defer.returnValue(
+            (
+                yield self.runInteraction(
+                    "find_next_generated_user_id", _find_next_generated_user_id
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     def get_3pid_guest_access_token(self, medium, address):
         ret = yield self._simple_select_one(
             "threepid_guest_access_tokens",
-            {
-                "medium": medium,
-                "address": address
-            },
-            ["guest_access_token"], True, 'get_3pid_guest_access_token'
+            {"medium": medium, "address": address},
+            ["guest_access_token"],
+            True,
+            'get_3pid_guest_access_token',
         )
         if ret:
             defer.returnValue(ret["guest_access_token"])
@@ -418,8 +426,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             Deferred[str|None]: user id or None if no user id/threepid mapping exists
         """
         user_id = yield self.runInteraction(
-            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
-            medium, address
+            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
         )
         defer.returnValue(user_id)
 
@@ -437,11 +444,9 @@ class RegistrationWorkerStore(SQLBaseStore):
         ret = self._simple_select_one_txn(
             txn,
             "user_threepids",
-            {
-                "medium": medium,
-                "address": address
-            },
-            ['user_id'], True
+            {"medium": medium, "address": address},
+            ['user_id'],
+            True,
         )
         if ret:
             return ret['user_id']
@@ -449,41 +454,110 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self._simple_upsert("user_threepids", {
-            "medium": medium,
-            "address": address,
-        }, {
-            "user_id": user_id,
-            "validated_at": validated_at,
-            "added_at": added_at,
-        })
+        yield self._simple_upsert(
+            "user_threepids",
+            {"medium": medium, "address": address},
+            {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
+        )
 
     @defer.inlineCallbacks
     def user_get_threepids(self, user_id):
         ret = yield self._simple_select_list(
-            "user_threepids", {
-                "user_id": user_id
-            },
+            "user_threepids",
+            {"user_id": user_id},
             ['medium', 'address', 'validated_at', 'added_at'],
-            'user_get_threepids'
+            'user_get_threepids',
         )
         defer.returnValue(ret)
 
     def user_delete_threepid(self, user_id, medium, address):
         return self._simple_delete(
             "user_threepids",
+            keyvalues={"user_id": user_id, "medium": medium, "address": address},
+            desc="user_delete_threepids",
+        )
+
+    def add_user_bound_threepid(self, user_id, medium, address, id_server):
+        """The server proxied a bind request to the given identity server on
+        behalf of the given user. We need to remember this in case the user
+        asks us to unbind the threepid.
+
+        Args:
+            user_id (str)
+            medium (str)
+            address (str)
+            id_server (str)
+
+        Returns:
+            Deferred
+        """
+        # We need to use an upsert, in case they user had already bound the
+        # threepid
+        return self._simple_upsert(
+            table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
                 "medium": medium,
                 "address": address,
+                "id_server": id_server,
             },
-            desc="user_delete_threepids",
+            values={},
+            insertion_values={},
+            desc="add_user_bound_threepid",
         )
 
+    def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+        """The server proxied an unbind request to the given identity server on
+        behalf of the given user, so we remove the mapping of threepid to
+        identity server.
+
+        Args:
+            user_id (str)
+            medium (str)
+            address (str)
+            id_server (str)
+
+        Returns:
+            Deferred
+        """
+        return self._simple_delete(
+            table="user_threepid_id_server",
+            keyvalues={
+                "user_id": user_id,
+                "medium": medium,
+                "address": address,
+                "id_server": id_server,
+            },
+            desc="remove_user_bound_threepid",
+        )
 
-class RegistrationStore(RegistrationWorkerStore,
-                        background_updates.BackgroundUpdateStore):
+    def get_id_servers_user_bound(self, user_id, medium, address):
+        """Get the list of identity servers that the server proxied bind
+        requests to for given user and threepid
 
+        Args:
+            user_id (str)
+            medium (str)
+            address (str)
+
+        Returns:
+            Deferred[list[str]]: Resolves to a list of identity servers
+        """
+        return self._simple_select_onecol(
+            table="user_threepid_id_server",
+            keyvalues={
+                "user_id": user_id,
+                "medium": medium,
+                "address": address,
+            },
+            retcol="id_server",
+            desc="get_id_servers_user_bound",
+        )
+
+
+class RegistrationStore(
+    RegistrationWorkerStore, background_updates.BackgroundUpdateStore
+):
     def __init__(self, db_conn, hs):
         super(RegistrationStore, self).__init__(db_conn, hs)
 
@@ -510,6 +584,10 @@ class RegistrationStore(RegistrationWorkerStore,
         # clear the background update.
         self.register_noop_background_update("refresh_tokens_device_index")
 
+        self.register_background_update_handler(
+            "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+        )
+
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
         """Adds an access token for the given user.
@@ -526,18 +604,22 @@ class RegistrationStore(RegistrationWorkerStore,
 
         yield self._simple_insert(
             "access_tokens",
-            {
-                "id": next_id,
-                "user_id": user_id,
-                "token": token,
-                "device_id": device_id,
-            },
+            {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
             desc="add_access_token_to_user",
         )
 
-    def register(self, user_id, token=None, password_hash=None,
-                 was_guest=False, make_guest=False, appservice_id=None,
-                 create_profile_with_displayname=None, admin=False, user_type=None):
+    def register(
+        self,
+        user_id,
+        token=None,
+        password_hash=None,
+        was_guest=False,
+        make_guest=False,
+        appservice_id=None,
+        create_profile_with_displayname=None,
+        admin=False,
+        user_type=None,
+    ):
         """Attempts to register an account.
 
         Args:
@@ -571,7 +653,7 @@ class RegistrationStore(RegistrationWorkerStore,
             appservice_id,
             create_profile_with_displayname,
             admin,
-            user_type
+            user_type,
         )
 
     def _register(
@@ -601,10 +683,7 @@ class RegistrationStore(RegistrationWorkerStore,
                 self._simple_select_one_txn(
                     txn,
                     "users",
-                    keyvalues={
-                        "name": user_id,
-                        "is_guest": 1,
-                    },
+                    keyvalues={"name": user_id, "is_guest": 1},
                     retcols=("name",),
                     allow_none=False,
                 )
@@ -612,10 +691,7 @@ class RegistrationStore(RegistrationWorkerStore,
                 self._simple_update_one_txn(
                     txn,
                     "users",
-                    keyvalues={
-                        "name": user_id,
-                        "is_guest": 1,
-                    },
+                    keyvalues={"name": user_id, "is_guest": 1},
                     updatevalues={
                         "password_hash": password_hash,
                         "upgrade_ts": now,
@@ -623,7 +699,7 @@ class RegistrationStore(RegistrationWorkerStore,
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
-                    }
+                    },
                 )
             else:
                 self._simple_insert_txn(
@@ -637,13 +713,11 @@ class RegistrationStore(RegistrationWorkerStore,
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
-                    }
+                    },
                 )
 
         except self.database_engine.module.IntegrityError:
-            raise StoreError(
-                400, "User ID already taken.", errcode=Codes.USER_IN_USE
-            )
+            raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
         if self._account_validity.enabled:
             now_ms = self.clock.time_msec()
@@ -662,9 +736,8 @@ class RegistrationStore(RegistrationWorkerStore,
             # it's possible for this to get a conflict, but only for a single user
             # since tokens are namespaced based on their user ID
             txn.execute(
-                "INSERT INTO access_tokens(id, user_id, token)"
-                " VALUES (?,?,?)",
-                (next_id, user_id, token,)
+                "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
+                (next_id, user_id, token),
             )
 
         if create_profile_with_displayname:
@@ -675,12 +748,10 @@ class RegistrationStore(RegistrationWorkerStore,
             # while everything else uses the full mxid.
             txn.execute(
                 "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
-                (user_id_obj.localpart, create_profile_with_displayname)
+                (user_id_obj.localpart, create_profile_with_displayname),
             )
 
-        self._invalidate_cache_and_stream(
-            txn, self.get_user_by_id, (user_id,)
-        )
+        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
     def user_set_password_hash(self, user_id, password_hash):
@@ -689,22 +760,14 @@ class RegistrationStore(RegistrationWorkerStore,
             removes most of the entries subsequently anyway so it would be
             pointless. Use flush_user separately.
         """
+
         def user_set_password_hash_txn(txn):
             self._simple_update_one_txn(
-                txn,
-                'users', {
-                    'name': user_id
-                },
-                {
-                    'password_hash': password_hash
-                }
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_user_by_id, (user_id,)
+                txn, 'users', {'name': user_id}, {'password_hash': password_hash}
             )
-        return self.runInteraction(
-            "user_set_password_hash", user_set_password_hash_txn
-        )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+        return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
 
     def user_set_consent_version(self, user_id, consent_version):
         """Updates the user table to record privacy policy consent
@@ -717,16 +780,16 @@ class RegistrationStore(RegistrationWorkerStore,
         Raises:
             StoreError(404) if user not found
         """
+
         def f(txn):
             self._simple_update_one_txn(
                 txn,
                 table='users',
-                keyvalues={'name': user_id, },
-                updatevalues={'consent_version': consent_version, },
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_user_by_id, (user_id,)
+                keyvalues={'name': user_id},
+                updatevalues={'consent_version': consent_version},
             )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
         return self.runInteraction("user_set_consent_version", f)
 
     def user_set_consent_server_notice_sent(self, user_id, consent_version):
@@ -741,20 +804,19 @@ class RegistrationStore(RegistrationWorkerStore,
         Raises:
             StoreError(404) if user not found
         """
+
         def f(txn):
             self._simple_update_one_txn(
                 txn,
                 table='users',
-                keyvalues={'name': user_id, },
-                updatevalues={'consent_server_notice_sent': consent_version, },
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_user_by_id, (user_id,)
+                keyvalues={'name': user_id},
+                updatevalues={'consent_server_notice_sent': consent_version},
             )
+            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
         return self.runInteraction("user_set_consent_server_notice_sent", f)
 
-    def user_delete_access_tokens(self, user_id, except_token_id=None,
-                                  device_id=None):
+    def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
         """
         Invalidate access tokens belonging to a user
 
@@ -769,10 +831,9 @@ class RegistrationStore(RegistrationWorkerStore,
             defer.Deferred[list[str, int, str|None, int]]: a list of
                 (token, token id, device id) for each of the deleted tokens
         """
+
         def f(txn):
-            keyvalues = {
-                "user_id": user_id,
-            }
+            keyvalues = {"user_id": user_id}
             if device_id is not None:
                 keyvalues["device_id"] = device_id
 
@@ -784,8 +845,9 @@ class RegistrationStore(RegistrationWorkerStore,
                 values.append(except_token_id)
 
             txn.execute(
-                "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
-                values
+                "SELECT token, id, device_id FROM access_tokens WHERE %s"
+                % where_clause,
+                values,
             )
             tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
@@ -794,25 +856,16 @@ class RegistrationStore(RegistrationWorkerStore,
                     txn, self.get_user_by_access_token, (token,)
                 )
 
-            txn.execute(
-                "DELETE FROM access_tokens WHERE %s" % where_clause,
-                values
-            )
+            txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 
             return tokens_and_devices
 
-        return self.runInteraction(
-            "user_delete_access_tokens", f,
-        )
+        return self.runInteraction("user_delete_access_tokens", f)
 
     def delete_access_token(self, access_token):
         def f(txn):
             self._simple_delete_one_txn(
-                txn,
-                table="access_tokens",
-                keyvalues={
-                    "token": access_token
-                },
+                txn, table="access_tokens", keyvalues={"token": access_token}
             )
 
             self._invalidate_cache_and_stream(
@@ -835,7 +888,7 @@ class RegistrationStore(RegistrationWorkerStore,
 
     @defer.inlineCallbacks
     def save_or_get_3pid_guest_access_token(
-            self, medium, address, access_token, inviter_user_id
+        self, medium, address, access_token, inviter_user_id
     ):
         """
         Gets the 3pid's guest access token if exists, else saves access_token.
@@ -851,12 +904,13 @@ class RegistrationStore(RegistrationWorkerStore,
             deferred str: Whichever access token is persisted at the end
             of this function call.
         """
+
         def insert(txn):
             txn.execute(
                 "INSERT INTO threepid_guest_access_tokens "
                 "(medium, address, guest_access_token, first_inviter) "
                 "VALUES (?, ?, ?, ?)",
-                (medium, address, access_token, inviter_user_id)
+                (medium, address, access_token, inviter_user_id),
             )
 
         try:
@@ -873,9 +927,7 @@ class RegistrationStore(RegistrationWorkerStore,
         """
         return self._simple_insert(
             "users_pending_deactivation",
-            values={
-                "user_id": user_id,
-            },
+            values={"user_id": user_id},
             desc="add_user_pending_deactivation",
         )
 
@@ -888,9 +940,7 @@ class RegistrationStore(RegistrationWorkerStore,
         # the table, so somehow duplicate entries have ended up in it.
         return self._simple_delete(
             "users_pending_deactivation",
-            keyvalues={
-                "user_id": user_id,
-            },
+            keyvalues={"user_id": user_id},
             desc="del_user_pending_deactivation",
         )
 
@@ -906,3 +956,34 @@ class RegistrationStore(RegistrationWorkerStore,
             allow_none=True,
             desc="get_users_pending_deactivation",
         )
+
+    @defer.inlineCallbacks
+    def _bg_user_threepids_grandfather(self, progress, batch_size):
+        """We now track which identity servers a user binds their 3PID to, so
+        we need to handle the case of existing bindings where we didn't track
+        this.
+
+        We do this by grandfathering in existing user threepids assuming that
+        they used one of the server configured trusted identity servers.
+        """
+
+        id_servers = set(self.config.trusted_third_party_id_servers)
+
+        def _bg_user_threepids_grandfather_txn(txn):
+            sql = """
+                INSERT INTO user_threepid_id_server
+                    (user_id, medium, address, id_server)
+                SELECT user_id, medium, address, ?
+                FROM user_threepids
+            """
+
+            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+        if id_servers:
+            yield self.runInteraction(
+                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+            )
+
+        yield self._end_background_update("user_threepids_grandfather")
+
+        defer.returnValue(1)