diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/rest/admin/users.py | 55 | ||||
-rw-r--r-- | synapse/storage/databases/main/registration.py | 25 |
2 files changed, 53 insertions, 27 deletions
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 3c8a0c6883..c1a1ba645e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet): if not isinstance(deactivate, bool): raise SynapseError(400, "'deactivated' parameter is not of type boolean") - # convert into List[Tuple[str, str]] + # convert List[Dict[str, str]] into Set[Tuple[str, str]] if external_ids is not None: - new_external_ids = [] - for external_id in external_ids: - new_external_ids.append( - (external_id["auth_provider"], external_id["external_id"]) - ) + new_external_ids = { + (external_id["auth_provider"], external_id["external_id"]) + for external_id in external_ids + } + + # convert List[Dict[str, str]] into Set[Tuple[str, str]] + if threepids is not None: + new_threepids = { + (threepid["medium"], threepid["address"]) for threepid in threepids + } if user: # modify user if "displayname" in body: @@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet): ) if threepids is not None: - # remove old threepids from user - old_threepids = await self.store.user_get_threepids(user_id) - for threepid in old_threepids: + # get changed threepids (added and removed) + # convert List[Dict[str, Any]] into Set[Tuple[str, str]] + cur_threepids = { + (threepid["medium"], threepid["address"]) + for threepid in await self.store.user_get_threepids(user_id) + } + add_threepids = new_threepids - cur_threepids + del_threepids = cur_threepids - new_threepids + + # remove old threepids + for medium, address in del_threepids: try: await self.auth_handler.delete_threepid( - user_id, threepid["medium"], threepid["address"], None + user_id, medium, address, None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") - # add new threepids to user + # add new threepids current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for medium, address in add_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if external_ids is not None: # get changed external_ids (added and removed) - cur_external_ids = await self.store.get_external_ids_by_user(user_id) - add_external_ids = set(new_external_ids) - set(cur_external_ids) - del_external_ids = set(cur_external_ids) - set(new_external_ids) + cur_external_ids = set( + await self.store.get_external_ids_by_user(user_id) + ) + add_external_ids = new_external_ids - cur_external_ids + del_external_ids = cur_external_ids - new_external_ids # remove old external_ids for auth_provider, external_id in del_external_ids: @@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet): if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for medium, address in new_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if ( self.hs.config.email_enable_notifs @@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet): kind="email", app_id="m.email", app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], + device_display_name=address, + pushkey=address, lang=None, # We don't know a user's language here data={}, ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c67bea81c6..469dd53e0c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) return user_id - def get_user_id_by_threepid_txn(self, txn, medium, address): + def get_user_id_by_threepid_txn( + self, txn, medium: str, address: str + ) -> Optional[str]: """Returns user id from threepid Args: txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com Returns: - str|None: user id or None if no user id/threepid mapping exists + user id, or None if no user id/threepid mapping exists """ ret = self.db_pool.simple_select_one_txn( txn, @@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return ret["user_id"] return None - async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + async def user_add_threepid( + self, + user_id: str, + medium: str, + address: str, + validated_at: int, + added_at: int, + ) -> None: await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id): + async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "user_get_threepids", ) - async def user_delete_threepid(self, user_id, medium, address) -> None: + async def user_delete_threepid( + self, user_id: str, medium: str, address: str + ) -> None: await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, |