summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/registration.py95
1 files changed, 90 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0ab56d8a07..37d47aa823 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,11 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
@@ -40,6 +44,13 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 logger = logging.getLogger(__name__)
 
 
+class ExternalIDReuseException(Exception):
+    """Exception if writing an external id for a user fails,
+    because this external id is given to an other user."""
+
+    pass
+
+
 @attr.s(frozen=True, slots=True)
 class TokenLookupResult:
     """Result of looking up an access token.
@@ -588,24 +599,44 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             auth_provider: identifier for the remote auth provider
             external_id: id on that system
             user_id: complete mxid that it is mapped to
+        Raises:
+            ExternalIDReuseException if the new external_id could not be mapped.
         """
-        await self.db_pool.simple_insert(
+
+        try:
+            await self.db_pool.runInteraction(
+                "record_user_external_id",
+                self._record_user_external_id_txn,
+                auth_provider,
+                external_id,
+                user_id,
+            )
+        except self.database_engine.module.IntegrityError:
+            raise ExternalIDReuseException()
+
+    def _record_user_external_id_txn(
+        self,
+        txn: LoggingTransaction,
+        auth_provider: str,
+        external_id: str,
+        user_id: str,
+    ) -> None:
+
+        self.db_pool.simple_insert_txn(
+            txn,
             table="user_external_ids",
             values={
                 "auth_provider": auth_provider,
                 "external_id": external_id,
                 "user_id": user_id,
             },
-            desc="record_user_external_id",
         )
 
     async def remove_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
     ) -> None:
         """Remove a mapping from an external user id to a mxid
-
         If the mapping is not found, this method does nothing.
-
         Args:
             auth_provider: identifier for the remote auth provider
             external_id: id on that system
@@ -621,6 +652,60 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="remove_user_external_id",
         )
 
+    async def replace_user_external_id(
+        self,
+        record_external_ids: List[Tuple[str, str]],
+        user_id: str,
+    ) -> None:
+        """Replace mappings from external user ids to a mxid in a single transaction.
+        All mappings are deleted and the new ones are created.
+
+        Args:
+            record_external_ids:
+                List with tuple of auth_provider and external_id to record
+            user_id: complete mxid that it is mapped to
+        Raises:
+            ExternalIDReuseException if the new external_id could not be mapped.
+        """
+
+        def _remove_user_external_ids_txn(
+            txn: LoggingTransaction,
+            user_id: str,
+        ) -> None:
+            """Remove all mappings from external user ids to a mxid
+            If these mappings are not found, this method does nothing.
+
+            Args:
+                user_id: complete mxid that it is mapped to
+            """
+
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="user_external_ids",
+                keyvalues={"user_id": user_id},
+            )
+
+        def _replace_user_external_id_txn(
+            txn: LoggingTransaction,
+        ):
+            _remove_user_external_ids_txn(txn, user_id)
+
+            for auth_provider, external_id in record_external_ids:
+                self._record_user_external_id_txn(
+                    txn,
+                    auth_provider,
+                    external_id,
+                    user_id,
+                )
+
+        try:
+            await self.db_pool.runInteraction(
+                "replace_user_external_id",
+                _replace_user_external_id_txn,
+            )
+        except self.database_engine.module.IntegrityError:
+            raise ExternalIDReuseException()
+
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
     ) -> Optional[str]: