summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py75
1 files changed, 68 insertions, 7 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index da27ad76b6..241a7be51e 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,6 +22,7 @@ from six import iterkeys
 from six.moves import range
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -384,6 +385,26 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def get_user_by_external_id(
+        self, auth_provider: str, external_id: str
+    ) -> str:
+        """Look up a user by their external auth id
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+
+        Returns:
+            str|None: the mxid of the user, or None if they are not known
+        """
+        return await self._simple_select_one_onecol(
+            table="user_external_ids",
+            keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_user_by_external_id",
+        )
+
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
@@ -495,7 +516,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_user_id_by_threepid(self, medium, address, require_verified=False):
+    def get_user_id_by_threepid(self, medium, address):
         """Returns user id from threepid
 
         Args:
@@ -586,6 +607,26 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="add_user_bound_threepid",
         )
 
+    def user_get_bound_threepids(self, user_id):
+        """Get the threepids that a user has bound to an identity server through the homeserver
+        The homeserver remembers where binds to an identity server occurred. Using this
+        method can retrieve those threepids.
+
+        Args:
+            user_id (str): The ID of the user to retrieve threepids for
+
+        Returns:
+            Deferred[list[dict]]: List of dictionaries containing the following:
+                medium (str): The medium of the threepid (e.g "email")
+                address (str): The address of the threepid (e.g "bob@example.com")
+        """
+        return self._simple_select_list(
+            table="user_threepid_id_server",
+            keyvalues={"user_id": user_id},
+            retcols=["medium", "address"],
+            desc="user_get_bound_threepids",
+        )
+
     def remove_user_bound_threepid(self, user_id, medium, address, id_server):
         """The server proxied an unbind request to the given identity server on
         behalf of the given user, so we remove the mapping of threepid to
@@ -655,7 +696,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         self, medium, client_secret, address=None, sid=None, validated=True
     ):
         """Gets a session_id and last_send_attempt (if available) for a
-        client_secret/medium/(address|session_id) combo
+        combination of validation metadata
 
         Args:
             medium (str|None): The medium of the 3PID
@@ -824,7 +865,7 @@ class RegistrationStore(
             rows = self.cursor_to_dict(txn)
 
             if not rows:
-                return True
+                return True, 0
 
             rows_processed_nb = 0
 
@@ -840,18 +881,18 @@ class RegistrationStore(
             )
 
             if batch_size > len(rows):
-                return True
+                return True, len(rows)
             else:
-                return False
+                return False, len(rows)
 
-        end = yield self.runInteraction(
+        end, nb_processed = yield self.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
             yield self._end_background_update("users_set_deactivated_flag")
 
-        return batch_size
+        return nb_processed
 
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -1012,6 +1053,26 @@ class RegistrationStore(
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
+    def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> Deferred:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        return self._simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     def user_set_password_hash(self, user_id, password_hash):
         """
         NB. This does *not* evict any cache because the one use for this