summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/rest/admin/users.py55
-rw-r--r--synapse/storage/databases/main/registration.py25
2 files changed, 53 insertions, 27 deletions
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3c8a0c6883..c1a1ba645e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
         if not isinstance(deactivate, bool):
             raise SynapseError(400, "'deactivated' parameter is not of type boolean")
 
-        # convert into List[Tuple[str, str]]
+        # convert List[Dict[str, str]] into Set[Tuple[str, str]]
         if external_ids is not None:
-            new_external_ids = []
-            for external_id in external_ids:
-                new_external_ids.append(
-                    (external_id["auth_provider"], external_id["external_id"])
-                )
+            new_external_ids = {
+                (external_id["auth_provider"], external_id["external_id"])
+                for external_id in external_ids
+            }
+
+        # convert List[Dict[str, str]] into Set[Tuple[str, str]]
+        if threepids is not None:
+            new_threepids = {
+                (threepid["medium"], threepid["address"]) for threepid in threepids
+            }
 
         if user:  # modify user
             if "displayname" in body:
@@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
                 )
 
             if threepids is not None:
-                # remove old threepids from user
-                old_threepids = await self.store.user_get_threepids(user_id)
-                for threepid in old_threepids:
+                # get changed threepids (added and removed)
+                # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
+                cur_threepids = {
+                    (threepid["medium"], threepid["address"])
+                    for threepid in await self.store.user_get_threepids(user_id)
+                }
+                add_threepids = new_threepids - cur_threepids
+                del_threepids = cur_threepids - new_threepids
+
+                # remove old threepids
+                for medium, address in del_threepids:
                     try:
                         await self.auth_handler.delete_threepid(
-                            user_id, threepid["medium"], threepid["address"], None
+                            user_id, medium, address, None
                         )
                     except Exception:
                         logger.exception("Failed to remove threepids")
                         raise SynapseError(500, "Failed to remove threepids")
 
-                # add new threepids to user
+                # add new threepids
                 current_time = self.hs.get_clock().time_msec()
-                for threepid in threepids:
+                for medium, address in add_threepids:
                     await self.auth_handler.add_threepid(
-                        user_id, threepid["medium"], threepid["address"], current_time
+                        user_id, medium, address, current_time
                     )
 
             if external_ids is not None:
                 # get changed external_ids (added and removed)
-                cur_external_ids = await self.store.get_external_ids_by_user(user_id)
-                add_external_ids = set(new_external_ids) - set(cur_external_ids)
-                del_external_ids = set(cur_external_ids) - set(new_external_ids)
+                cur_external_ids = set(
+                    await self.store.get_external_ids_by_user(user_id)
+                )
+                add_external_ids = new_external_ids - cur_external_ids
+                del_external_ids = cur_external_ids - new_external_ids
 
                 # remove old external_ids
                 for auth_provider, external_id in del_external_ids:
@@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
 
             if threepids is not None:
                 current_time = self.hs.get_clock().time_msec()
-                for threepid in threepids:
+                for medium, address in new_threepids:
                     await self.auth_handler.add_threepid(
-                        user_id, threepid["medium"], threepid["address"], current_time
+                        user_id, medium, address, current_time
                     )
                     if (
                         self.hs.config.email_enable_notifs
@@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
                             kind="email",
                             app_id="m.email",
                             app_display_name="Email Notifications",
-                            device_display_name=threepid["address"],
-                            pushkey=threepid["address"],
+                            device_display_name=address,
+                            pushkey=address,
                             lang=None,  # We don't know a user's language here
                             data={},
                         )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81c6..469dd53e0c 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return user_id
 
-    def get_user_id_by_threepid_txn(self, txn, medium, address):
+    def get_user_id_by_threepid_txn(
+        self, txn, medium: str, address: str
+    ) -> Optional[str]:
         """Returns user id from threepid
 
         Args:
             txn (cursor):
-            medium (str): threepid medium e.g. email
-            address (str): threepid address e.g. me@example.com
+            medium: threepid medium e.g. email
+            address: threepid address e.g. me@example.com
 
         Returns:
-            str|None: user id or None if no user id/threepid mapping exists
+            user id, or None if no user id/threepid mapping exists
         """
         ret = self.db_pool.simple_select_one_txn(
             txn,
@@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             return ret["user_id"]
         return None
 
-    async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+    async def user_add_threepid(
+        self,
+        user_id: str,
+        medium: str,
+        address: str,
+        validated_at: int,
+        added_at: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    async def user_get_threepids(self, user_id):
+    async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
         return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
@@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "user_get_threepids",
         )
 
-    async def user_delete_threepid(self, user_id, medium, address) -> None:
+    async def user_delete_threepid(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
         await self.db_pool.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},