diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 492adb6160..cf71bbb461 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -4854,3 +4854,59 @@ class UsersByThreePidTestCase(unittest.HomeserverTestCase):
{"user_id": self.other_user},
channel.json_body,
)
+
+
+class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ @staticmethod
+ def url(user: str) -> str:
+ template = (
+ "/_synapse/admin/v1/users/{}/_allow_cross_signing_replacement_without_uia"
+ )
+ return template.format(urllib.parse.quote(user))
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+
+ def test_error_cases(self) -> None:
+ fake_user = "@bums:other"
+ channel = self.make_request(
+ "POST", self.url(fake_user), access_token=self.admin_user_tok
+ )
+ # Fail: user doesn't exist
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+
+ channel = self.make_request(
+ "POST", self.url(self.other_user), access_token=self.admin_user_tok
+ )
+ # Fail: user exists, but has no master cross-signing key
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+
+ def test_success(self) -> None:
+ # Upload a master key.
+ dummy_key = {"keys": {"a": "b"}}
+ self.get_success(
+ self.store.set_e2e_cross_signing_key(self.other_user, "master", dummy_key)
+ )
+
+ channel = self.make_request(
+ "POST", self.url(self.other_user), access_token=self.admin_user_tok
+ )
+ # Success!
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Should now find that the key exists.
+ _, timestamp = self.get_success(
+ self.store.get_master_cross_signing_key_updatable_before(self.other_user)
+ )
+ assert timestamp is not None
+ self.assertGreater(timestamp, self.clock.time_msec())
|