diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 51146c471d..afaa597f65 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1,6 +1,4 @@
-# Copyright 2015-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,16 +15,22 @@ import json
import os
import re
from email.parser import Parser
-from typing import Optional
+from typing import Dict, List, Optional
+from unittest.mock import Mock
import pkg_resources
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes, HttpResponseException
from synapse.appservice import ApplicationService
+from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.server import FakeSite, make_request
@@ -1040,3 +1044,195 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
self.assertIn(expected_email, threepids)
+
+
+class AccountStatusTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ account.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["experimental_features"] = {"msc3720_enabled": True}
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ self.requester = self.register_user("requester", "password")
+ self.requester_tok = self.login("requester", "password")
+ self.server_name = homeserver.config.server.server_name
+
+ def test_missing_mxid(self):
+ """Tests that not providing any MXID raises an error."""
+ self._test_status(
+ users=None,
+ expected_status_code=400,
+ expected_errcode=Codes.MISSING_PARAM,
+ )
+
+ def test_invalid_mxid(self):
+ """Tests that providing an invalid MXID raises an error."""
+ self._test_status(
+ users=["bad:test"],
+ expected_status_code=400,
+ expected_errcode=Codes.INVALID_PARAM,
+ )
+
+ def test_local_user_not_exists(self):
+ """Tests that the account status endpoints correctly reports that a user doesn't
+ exist.
+ """
+ user = "@unknown:" + self.hs.config.server.server_name
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": False,
+ },
+ },
+ expected_failures=[],
+ )
+
+ def test_local_user_exists(self):
+ """Tests that the account status endpoint correctly reports that a user doesn't
+ exist.
+ """
+ user = self.register_user("someuser", "password")
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ },
+ },
+ expected_failures=[],
+ )
+
+ def test_local_user_deactivated(self):
+ """Tests that the account status endpoint correctly reports a deactivated user."""
+ user = self.register_user("someuser", "password")
+ self.get_success(
+ self.hs.get_datastore().set_user_deactivated_status(user, deactivated=True)
+ )
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": True,
+ },
+ },
+ expected_failures=[],
+ )
+
+ def test_mixed_local_and_remote_users(self):
+ """Tests that if some users are remote the account status endpoint correctly
+ merges the remote responses with the local result.
+ """
+ # We use 3 users: one doesn't exist but belongs on the local homeserver, one is
+ # deactivated and belongs on one remote homeserver, and one belongs to another
+ # remote homeserver that didn't return any result (the federation code should
+ # mark that user as a failure).
+ users = [
+ "@unknown:" + self.hs.config.server.server_name,
+ "@deactivated:remote",
+ "@failed:otherremote",
+ "@bad:badremote",
+ ]
+
+ async def post_json(destination, path, data, *a, **kwa):
+ if destination == "remote":
+ return {
+ "account_statuses": {
+ users[1]: {
+ "exists": True,
+ "deactivated": True,
+ },
+ }
+ }
+ if destination == "otherremote":
+ return {}
+ if destination == "badremote":
+ # badremote tries to overwrite the status of a user that doesn't belong
+ # to it (i.e. users[1]) with false data, which Synapse is expected to
+ # ignore.
+ return {
+ "account_statuses": {
+ users[3]: {
+ "exists": False,
+ },
+ users[1]: {
+ "exists": False,
+ },
+ }
+ }
+
+ # Register a mock that will return the expected result depending on the remote.
+ self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+
+ # Check that we've got the correct response from the client-side endpoint.
+ self._test_status(
+ users=users,
+ expected_statuses={
+ users[0]: {
+ "exists": False,
+ },
+ users[1]: {
+ "exists": True,
+ "deactivated": True,
+ },
+ users[3]: {
+ "exists": False,
+ },
+ },
+ expected_failures=[users[2]],
+ )
+
+ def _test_status(
+ self,
+ users: Optional[List[str]],
+ expected_status_code: int = 200,
+ expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
+ expected_failures: Optional[List[str]] = None,
+ expected_errcode: Optional[str] = None,
+ ):
+ """Send a request to the account status endpoint and check that the response
+ matches with what's expected.
+
+ Args:
+ users: The account(s) to request the status of, if any. If set to None, no
+ `user_id` query parameter will be included in the request.
+ expected_status_code: The expected HTTP status code.
+ expected_statuses: The expected account statuses, if any.
+ expected_failures: The expected failures, if any.
+ expected_errcode: The expected Matrix error code, if any.
+ """
+ content = {}
+ if users is not None:
+ content["user_ids"] = users
+
+ channel = self.make_request(
+ method="POST",
+ path=self.url,
+ content=content,
+ access_token=self.requester_tok,
+ )
+
+ self.assertEqual(channel.code, expected_status_code)
+
+ if expected_statuses is not None:
+ self.assertEqual(channel.json_body["account_statuses"], expected_statuses)
+
+ if expected_failures is not None:
+ self.assertEqual(channel.json_body["failures"], expected_failures)
+
+ if expected_errcode is not None:
+ self.assertEqual(channel.json_body["errcode"], expected_errcode)
|