summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-09-13 15:20:49 +0100
committerRichard van der Hoff <richard@matrix.org>2019-09-13 16:01:46 +0100
commita8ac40445c98b9e1fc2538d7d4ec49c80b0298ac (patch)
treee506896dcaa02f826ffe0e5e1c859acabb290626 /synapse/storage/registration.py
parentMake the sample saml config closer to our standards (diff)
downloadsynapse-a8ac40445c98b9e1fc2538d7d4ec49c80b0298ac.tar.xz
Record mappings from saml users in an external table
We want to assign unique mxids to saml users based on an incrementing
suffix. For that to work, we need to record the allocated mxid in a separate
table.
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 55e4e84d71..1e3c2148f6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,6 +22,7 @@ from six import iterkeys
 from six.moves import range
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, ThreepidValidationError
@@ -337,6 +338,26 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def get_user_by_external_id(
+        self, auth_provider: str, external_id: str
+    ) -> str:
+        """Look up a user by their external auth id
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+
+        Returns:
+            str|None: the mxid of the user, or None if they are not known
+        """
+        return await self._simple_select_one_onecol(
+            table="user_external_ids",
+            keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="get_user_by_external_id",
+        )
+
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
@@ -848,6 +869,26 @@ class RegistrationStore(
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
+    def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> Deferred:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        return self._simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     def user_set_password_hash(self, user_id, password_hash):
         """
         NB. This does *not* evict any cache because the one use for this