summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py112
1 files changed, 112 insertions, 0 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 0a66a2c17d..e30b86c346 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -325,6 +325,83 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="user_delete_threepids",
         )
 
+    def add_user_bound_threepid(self, user_id, medium, address, id_server):
+        """The server proxied a bind request to the given identity server on
+        behalf of the given user. We need to remember this in case the user
+        asks us to unbind the threepid.
+
+        Args:
+            user_id (str)
+            medium (str)
+            address (str)
+            id_server (str)
+
+        Returns:
+            Deferred
+        """
+        # We need to use an upsert, in case they user had already bound the
+        # threepid
+        return self._simple_upsert(
+            table="user_threepid_id_server",
+            keyvalues={
+                "user_id": user_id,
+                "medium": medium,
+                "address": address,
+                "id_server": id_server,
+            },
+            values={},
+            insertion_values={},
+            desc="add_user_bound_threepid",
+        )
+
+    def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+        """The server proxied an unbind request to the given identity server on
+        behalf of the given user, so we remove the mapping of threepid to
+        identity server.
+
+        Args:
+            user_id (str)
+            medium (str)
+            address (str)
+            id_server (str)
+
+        Returns:
+            Deferred
+        """
+        return self._simple_delete(
+            table="user_threepid_id_server",
+            keyvalues={
+                "user_id": user_id,
+                "medium": medium,
+                "address": address,
+                "id_server": id_server,
+            },
+            desc="remove_user_bound_threepid",
+        )
+
+    def get_id_servers_user_bound(self, user_id, medium, address):
+        """Get the list of identity servers that the server proxied bind
+        requests to for given user and threepid
+
+        Args:
+            user_id (str)
+            medium (str)
+            address (str)
+
+        Returns:
+            Deferred[list[str]]: Resolves to a list of identity servers
+        """
+        return self._simple_select_onecol(
+            table="user_threepid_id_server",
+            keyvalues={
+                "user_id": user_id,
+                "medium": medium,
+                "address": address,
+            },
+            retcol="id_server",
+            desc="get_id_servers_user_bound",
+        )
+
 
 class RegistrationStore(
     RegistrationWorkerStore, background_updates.BackgroundUpdateStore
@@ -353,6 +430,10 @@ class RegistrationStore(
         # clear the background update.
         self.register_noop_background_update("refresh_tokens_device_index")
 
+        self.register_background_update_handler(
+            "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+        )
+
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
         """Adds an access token for the given user.
@@ -707,3 +788,34 @@ class RegistrationStore(
             allow_none=True,
             desc="get_users_pending_deactivation",
         )
+
+    @defer.inlineCallbacks
+    def _bg_user_threepids_grandfather(self, progress, batch_size):
+        """We now track which identity servers a user binds their 3PID to, so
+        we need to handle the case of existing bindings where we didn't track
+        this.
+
+        We do this by grandfathering in existing user threepids assuming that
+        they used one of the server configured trusted identity servers.
+        """
+
+        id_servers = set(self.config.trusted_third_party_id_servers)
+
+        def _bg_user_threepids_grandfather_txn(txn):
+            sql = """
+                INSERT INTO user_threepid_id_server
+                    (user_id, medium, address, id_server)
+                SELECT user_id, medium, address, ?
+                FROM user_threepids
+            """
+
+            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+        if id_servers:
+            yield self.runInteraction(
+                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+            )
+
+        yield self._end_background_update("user_threepids_grandfather")
+
+        defer.returnValue(1)