diff --git a/changelog.d/10627.misc b/changelog.d/10627.misc
new file mode 100644
index 0000000000..e6d314976e
--- /dev/null
+++ b/changelog.d/10627.misc
@@ -0,0 +1 @@
+Remove not needed database updates in modify user admin API.
\ No newline at end of file
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index 6a9335d6ec..60dc913915 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -21,11 +21,15 @@ It returns a JSON body like the following:
"threepids": [
{
"medium": "email",
- "address": "<user_mail_1>"
+ "address": "<user_mail_1>",
+ "added_at": 1586458409743,
+ "validated_at": 1586458409743
},
{
"medium": "email",
- "address": "<user_mail_2>"
+ "address": "<user_mail_2>",
+ "added_at": 1586458409743,
+ "validated_at": 1586458409743
}
],
"avatar_url": "<avatar_url>",
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3c8a0c6883..c1a1ba645e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
if not isinstance(deactivate, bool):
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
- # convert into List[Tuple[str, str]]
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
if external_ids is not None:
- new_external_ids = []
- for external_id in external_ids:
- new_external_ids.append(
- (external_id["auth_provider"], external_id["external_id"])
- )
+ new_external_ids = {
+ (external_id["auth_provider"], external_id["external_id"])
+ for external_id in external_ids
+ }
+
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
+ if threepids is not None:
+ new_threepids = {
+ (threepid["medium"], threepid["address"]) for threepid in threepids
+ }
if user: # modify user
if "displayname" in body:
@@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
)
if threepids is not None:
- # remove old threepids from user
- old_threepids = await self.store.user_get_threepids(user_id)
- for threepid in old_threepids:
+ # get changed threepids (added and removed)
+ # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
+ cur_threepids = {
+ (threepid["medium"], threepid["address"])
+ for threepid in await self.store.user_get_threepids(user_id)
+ }
+ add_threepids = new_threepids - cur_threepids
+ del_threepids = cur_threepids - new_threepids
+
+ # remove old threepids
+ for medium, address in del_threepids:
try:
await self.auth_handler.delete_threepid(
- user_id, threepid["medium"], threepid["address"], None
+ user_id, medium, address, None
)
except Exception:
logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids")
- # add new threepids to user
+ # add new threepids
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in add_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if external_ids is not None:
# get changed external_ids (added and removed)
- cur_external_ids = await self.store.get_external_ids_by_user(user_id)
- add_external_ids = set(new_external_ids) - set(cur_external_ids)
- del_external_ids = set(cur_external_ids) - set(new_external_ids)
+ cur_external_ids = set(
+ await self.store.get_external_ids_by_user(user_id)
+ )
+ add_external_ids = new_external_ids - cur_external_ids
+ del_external_ids = cur_external_ids - new_external_ids
# remove old external_ids
for auth_provider, external_id in del_external_ids:
@@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
if threepids is not None:
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in new_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if (
self.hs.config.email_enable_notifs
@@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
- device_display_name=threepid["address"],
- pushkey=threepid["address"],
+ device_display_name=address,
+ pushkey=address,
lang=None, # We don't know a user's language here
data={},
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81c6..469dd53e0c 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return user_id
- def get_user_id_by_threepid_txn(self, txn, medium, address):
+ def get_user_id_by_threepid_txn(
+ self, txn, medium: str, address: str
+ ) -> Optional[str]:
"""Returns user id from threepid
Args:
txn (cursor):
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
+ medium: threepid medium e.g. email
+ address: threepid address e.g. me@example.com
Returns:
- str|None: user id or None if no user id/threepid mapping exists
+ user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
txn,
@@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return ret["user_id"]
return None
- async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ async def user_add_threepid(
+ self,
+ user_id: str,
+ medium: str,
+ address: str,
+ validated_at: int,
+ added_at: int,
+ ) -> None:
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id):
+ async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"user_get_threepids",
)
- async def user_delete_threepid(self, user_id, medium, address) -> None:
+ async def user_delete_threepid(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ef77275238..ee204c404b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
self.assertEqual(
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
)
+ self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Test setting threepid for an other user.
"""
- # Delete old and add new threepid to user
+ # Add two threepids to user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob1@bob.bob"},
+ {"medium": "email", "address": "bob2@bob.bob"},
+ ],
+ },
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
+ # result does not always have the same sort order, therefore it becomes sorted
+ sorted_result = sorted(
+ channel.json_body["threepids"], key=lambda k: k["address"]
+ )
+ self.assertEqual("email", sorted_result[0]["medium"])
+ self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+ self.assertEqual("email", sorted_result[1]["medium"])
+ self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Set a new and remove a threepid
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob2@bob.bob"},
+ {"medium": "email", "address": "bob3@bob.bob"},
+ ],
+ },
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
# Get user
channel = self.make_request(
@@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Remove threepids
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"threepids": []},
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self._check_fields(channel.json_body)
def test_set_external_id(self):
"""
@@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
channel.json_body["external_ids"],
[
|