summary refs log tree commit diff
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2021-10-21 10:52:32 +0200
committerGitHub <noreply@github.com>2021-10-21 09:52:32 +0100
commitef7fe09778ad672d6ed80fb2206cfbc11e6a9a5e (patch)
treecbdc5577b45755b7e0c828f4dc032c7603a159b4
parentUpdate `sign_json` to support inline key config (#11139) (diff)
downloadsynapse-ef7fe09778ad672d6ed80fb2206cfbc11e6a9a5e.tar.xz
Fix setting a user's external_id via the admin API returns 500 and deletes users existing external mappings if that external ID is already mapped (#11051)
Fixes #10846
-rw-r--r--changelog.d/11051.bugfix1
-rw-r--r--synapse/rest/admin/users.py47
-rw-r--r--synapse/storage/databases/main/registration.py95
-rw-r--r--tests/rest/admin/test_user.py215
4 files changed, 321 insertions, 37 deletions
diff --git a/changelog.d/11051.bugfix b/changelog.d/11051.bugfix
new file mode 100644
index 0000000000..63126843d2
--- /dev/null
+++ b/changelog.d/11051.bugfix
@@ -0,0 +1 @@
+Fix a bug where setting a user's external_id via the admin API returns 500 and deletes users existing external mappings if that external ID is already mapped.
\ No newline at end of file
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f20aa65301..c0bebc3cf0 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
     assert_user_is_admin,
 )
 from synapse.rest.client._base import client_patterns
+from synapse.storage.databases.main.registration import ExternalIDReuseException
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.types import JsonDict, UserID
 
@@ -228,12 +229,12 @@ class UserRestServletV2(RestServlet):
         if not isinstance(deactivate, bool):
             raise SynapseError(400, "'deactivated' parameter is not of type boolean")
 
-        # convert List[Dict[str, str]] into Set[Tuple[str, str]]
+        # convert List[Dict[str, str]] into List[Tuple[str, str]]
         if external_ids is not None:
-            new_external_ids = {
+            new_external_ids = [
                 (external_id["auth_provider"], external_id["external_id"])
                 for external_id in external_ids
-            }
+            ]
 
         # convert List[Dict[str, str]] into Set[Tuple[str, str]]
         if threepids is not None:
@@ -275,28 +276,13 @@ class UserRestServletV2(RestServlet):
                     )
 
             if external_ids is not None:
-                # get changed external_ids (added and removed)
-                cur_external_ids = set(
-                    await self.store.get_external_ids_by_user(user_id)
-                )
-                add_external_ids = new_external_ids - cur_external_ids
-                del_external_ids = cur_external_ids - new_external_ids
-
-                # remove old external_ids
-                for auth_provider, external_id in del_external_ids:
-                    await self.store.remove_user_external_id(
-                        auth_provider,
-                        external_id,
-                        user_id,
-                    )
-
-                # add new external_ids
-                for auth_provider, external_id in add_external_ids:
-                    await self.store.record_user_external_id(
-                        auth_provider,
-                        external_id,
+                try:
+                    await self.store.replace_user_external_id(
+                        new_external_ids,
                         user_id,
                     )
+                except ExternalIDReuseException:
+                    raise SynapseError(409, "External id is already in use.")
 
             if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
@@ -384,12 +370,15 @@ class UserRestServletV2(RestServlet):
                         )
 
             if external_ids is not None:
-                for auth_provider, external_id in new_external_ids:
-                    await self.store.record_user_external_id(
-                        auth_provider,
-                        external_id,
-                        user_id,
-                    )
+                try:
+                    for auth_provider, external_id in new_external_ids:
+                        await self.store.record_user_external_id(
+                            auth_provider,
+                            external_id,
+                            user_id,
+                        )
+                except ExternalIDReuseException:
+                    raise SynapseError(409, "External id is already in use.")
 
             if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0ab56d8a07..37d47aa823 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,11 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
@@ -40,6 +44,13 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 logger = logging.getLogger(__name__)
 
 
+class ExternalIDReuseException(Exception):
+    """Exception if writing an external id for a user fails,
+    because this external id is given to an other user."""
+
+    pass
+
+
 @attr.s(frozen=True, slots=True)
 class TokenLookupResult:
     """Result of looking up an access token.
@@ -588,24 +599,44 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             auth_provider: identifier for the remote auth provider
             external_id: id on that system
             user_id: complete mxid that it is mapped to
+        Raises:
+            ExternalIDReuseException if the new external_id could not be mapped.
         """
-        await self.db_pool.simple_insert(
+
+        try:
+            await self.db_pool.runInteraction(
+                "record_user_external_id",
+                self._record_user_external_id_txn,
+                auth_provider,
+                external_id,
+                user_id,
+            )
+        except self.database_engine.module.IntegrityError:
+            raise ExternalIDReuseException()
+
+    def _record_user_external_id_txn(
+        self,
+        txn: LoggingTransaction,
+        auth_provider: str,
+        external_id: str,
+        user_id: str,
+    ) -> None:
+
+        self.db_pool.simple_insert_txn(
+            txn,
             table="user_external_ids",
             values={
                 "auth_provider": auth_provider,
                 "external_id": external_id,
                 "user_id": user_id,
             },
-            desc="record_user_external_id",
         )
 
     async def remove_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
     ) -> None:
         """Remove a mapping from an external user id to a mxid
-
         If the mapping is not found, this method does nothing.
-
         Args:
             auth_provider: identifier for the remote auth provider
             external_id: id on that system
@@ -621,6 +652,60 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="remove_user_external_id",
         )
 
+    async def replace_user_external_id(
+        self,
+        record_external_ids: List[Tuple[str, str]],
+        user_id: str,
+    ) -> None:
+        """Replace mappings from external user ids to a mxid in a single transaction.
+        All mappings are deleted and the new ones are created.
+
+        Args:
+            record_external_ids:
+                List with tuple of auth_provider and external_id to record
+            user_id: complete mxid that it is mapped to
+        Raises:
+            ExternalIDReuseException if the new external_id could not be mapped.
+        """
+
+        def _remove_user_external_ids_txn(
+            txn: LoggingTransaction,
+            user_id: str,
+        ) -> None:
+            """Remove all mappings from external user ids to a mxid
+            If these mappings are not found, this method does nothing.
+
+            Args:
+                user_id: complete mxid that it is mapped to
+            """
+
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="user_external_ids",
+                keyvalues={"user_id": user_id},
+            )
+
+        def _replace_user_external_id_txn(
+            txn: LoggingTransaction,
+        ):
+            _remove_user_external_ids_txn(txn, user_id)
+
+            for auth_provider, external_id in record_external_ids:
+                self._record_user_external_id_txn(
+                    txn,
+                    auth_provider,
+                    external_id,
+                    user_id,
+                )
+
+        try:
+            await self.db_pool.runInteraction(
+                "replace_user_external_id",
+                _replace_user_external_id_txn,
+            )
+        except self.database_engine.module.IntegrityError:
+            raise ExternalIDReuseException()
+
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
     ) -> Optional[str]:
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index c9e2754b09..839442ddba 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1180,9 +1180,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
                 self.other_user, device_id=None, valid_until_ms=None
             )
         )
-        self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
-            self.other_user
-        )
+        self.url_prefix = "/_synapse/admin/v2/users/%s"
+        self.url_other_user = self.url_prefix % self.other_user
 
     def test_requester_is_no_admin(self):
         """
@@ -1738,6 +1737,93 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self._check_fields(channel.json_body)
 
+    def test_set_duplicate_threepid(self):
+        """
+        Test setting the same threepid for a second user.
+        First user loses and second user gets mapping of this threepid.
+        """
+
+        # create a user to set a threepid
+        first_user = self.register_user("first_user", "pass")
+        url_first_user = self.url_prefix % first_user
+
+        # Add threepid to first user
+        channel = self.make_request(
+            "PUT",
+            url_first_user,
+            access_token=self.admin_user_tok,
+            content={
+                "threepids": [
+                    {"medium": "email", "address": "bob1@bob.bob"},
+                ],
+            },
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(first_user, channel.json_body["name"])
+        self.assertEqual(1, len(channel.json_body["threepids"]))
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual("bob1@bob.bob", channel.json_body["threepids"][0]["address"])
+        self._check_fields(channel.json_body)
+
+        # Add threepids to other user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={
+                "threepids": [
+                    {"medium": "email", "address": "bob2@bob.bob"},
+                ],
+            },
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(1, len(channel.json_body["threepids"]))
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+        self._check_fields(channel.json_body)
+
+        # Add two new threepids to other user
+        # one is used by first_user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={
+                "threepids": [
+                    {"medium": "email", "address": "bob1@bob.bob"},
+                    {"medium": "email", "address": "bob3@bob.bob"},
+                ],
+            },
+        )
+
+        # other user has this two threepids
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["threepids"]))
+        # result does not always have the same sort order, therefore it becomes sorted
+        sorted_result = sorted(
+            channel.json_body["threepids"], key=lambda k: k["address"]
+        )
+        self.assertEqual("email", sorted_result[0]["medium"])
+        self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+        self.assertEqual("email", sorted_result[1]["medium"])
+        self.assertEqual("bob3@bob.bob", sorted_result[1]["address"])
+        self._check_fields(channel.json_body)
+
+        # first_user has no threepid anymore
+        channel = self.make_request(
+            "GET",
+            url_first_user,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(first_user, channel.json_body["name"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self._check_fields(channel.json_body)
+
     def test_set_external_id(self):
         """
         Test setting external id for an other user.
@@ -1836,6 +1922,129 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(0, len(channel.json_body["external_ids"]))
 
+    def test_set_duplicate_external_id(self):
+        """
+        Test that setting the same external id for a second user fails and
+        external id from user must not be changed.
+        """
+
+        # create a user to use an external id
+        first_user = self.register_user("first_user", "pass")
+        url_first_user = self.url_prefix % first_user
+
+        # Add an external id to first user
+        channel = self.make_request(
+            "PUT",
+            url_first_user,
+            access_token=self.admin_user_tok,
+            content={
+                "external_ids": [
+                    {
+                        "external_id": "external_id1",
+                        "auth_provider": "auth_provider",
+                    },
+                ],
+            },
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(first_user, channel.json_body["name"])
+        self.assertEqual(1, len(channel.json_body["external_ids"]))
+        self.assertEqual(
+            "external_id1", channel.json_body["external_ids"][0]["external_id"]
+        )
+        self.assertEqual(
+            "auth_provider", channel.json_body["external_ids"][0]["auth_provider"]
+        )
+        self._check_fields(channel.json_body)
+
+        # Add an external id to other user
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={
+                "external_ids": [
+                    {
+                        "external_id": "external_id2",
+                        "auth_provider": "auth_provider",
+                    },
+                ],
+            },
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(1, len(channel.json_body["external_ids"]))
+        self.assertEqual(
+            "external_id2", channel.json_body["external_ids"][0]["external_id"]
+        )
+        self.assertEqual(
+            "auth_provider", channel.json_body["external_ids"][0]["auth_provider"]
+        )
+        self._check_fields(channel.json_body)
+
+        # Add two new external_ids to other user
+        # one is used by first
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={
+                "external_ids": [
+                    {
+                        "external_id": "external_id1",
+                        "auth_provider": "auth_provider",
+                    },
+                    {
+                        "external_id": "external_id3",
+                        "auth_provider": "auth_provider",
+                    },
+                ],
+            },
+        )
+
+        # must fail
+        self.assertEqual(409, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual("External id is already in use.", channel.json_body["error"])
+
+        # other user must not changed
+        channel = self.make_request(
+            "GET",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(1, len(channel.json_body["external_ids"]))
+        self.assertEqual(
+            "external_id2", channel.json_body["external_ids"][0]["external_id"]
+        )
+        self.assertEqual(
+            "auth_provider", channel.json_body["external_ids"][0]["auth_provider"]
+        )
+        self._check_fields(channel.json_body)
+
+        # first user must not changed
+        channel = self.make_request(
+            "GET",
+            url_first_user,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(first_user, channel.json_body["name"])
+        self.assertEqual(1, len(channel.json_body["external_ids"]))
+        self.assertEqual(
+            "external_id1", channel.json_body["external_ids"][0]["external_id"]
+        )
+        self.assertEqual(
+            "auth_provider", channel.json_body["external_ids"][0]["auth_provider"]
+        )
+        self._check_fields(channel.json_body)
+
     def test_deactivate_user(self):
         """
         Test deactivating another user.