summary refs log tree commit diff
path: root/tests/rest/client/test_account.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_account.py')
-rw-r--r--tests/rest/client/test_account.py204
1 files changed, 200 insertions, 4 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 51146c471d..afaa597f65 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1,6 +1,4 @@
-# Copyright 2015-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,16 +15,22 @@ import json
 import os
 import re
 from email.parser import Parser
-from typing import Optional
+from typing import Dict, List, Optional
+from unittest.mock import Mock
 
 import pkg_resources
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import LoginType, Membership
 from synapse.api.errors import Codes, HttpResponseException
 from synapse.appservice import ApplicationService
+from synapse.rest import admin
 from synapse.rest.client import account, login, register, room
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeSite, make_request
@@ -1040,3 +1044,195 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
         self.assertIn(expected_email, threepids)
+
+
+class AccountStatusTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        account.register_servlets,
+        admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config["experimental_features"] = {"msc3720_enabled": True}
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+        self.requester = self.register_user("requester", "password")
+        self.requester_tok = self.login("requester", "password")
+        self.server_name = homeserver.config.server.server_name
+
+    def test_missing_mxid(self):
+        """Tests that not providing any MXID raises an error."""
+        self._test_status(
+            users=None,
+            expected_status_code=400,
+            expected_errcode=Codes.MISSING_PARAM,
+        )
+
+    def test_invalid_mxid(self):
+        """Tests that providing an invalid MXID raises an error."""
+        self._test_status(
+            users=["bad:test"],
+            expected_status_code=400,
+            expected_errcode=Codes.INVALID_PARAM,
+        )
+
+    def test_local_user_not_exists(self):
+        """Tests that the account status endpoints correctly reports that a user doesn't
+        exist.
+        """
+        user = "@unknown:" + self.hs.config.server.server_name
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": False,
+                },
+            },
+            expected_failures=[],
+        )
+
+    def test_local_user_exists(self):
+        """Tests that the account status endpoint correctly reports that a user doesn't
+        exist.
+        """
+        user = self.register_user("someuser", "password")
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                },
+            },
+            expected_failures=[],
+        )
+
+    def test_local_user_deactivated(self):
+        """Tests that the account status endpoint correctly reports a deactivated user."""
+        user = self.register_user("someuser", "password")
+        self.get_success(
+            self.hs.get_datastore().set_user_deactivated_status(user, deactivated=True)
+        )
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": True,
+                },
+            },
+            expected_failures=[],
+        )
+
+    def test_mixed_local_and_remote_users(self):
+        """Tests that if some users are remote the account status endpoint correctly
+        merges the remote responses with the local result.
+        """
+        # We use 3 users: one doesn't exist but belongs on the local homeserver, one is
+        # deactivated and belongs on one remote homeserver, and one belongs to another
+        # remote homeserver that didn't return any result (the federation code should
+        # mark that user as a failure).
+        users = [
+            "@unknown:" + self.hs.config.server.server_name,
+            "@deactivated:remote",
+            "@failed:otherremote",
+            "@bad:badremote",
+        ]
+
+        async def post_json(destination, path, data, *a, **kwa):
+            if destination == "remote":
+                return {
+                    "account_statuses": {
+                        users[1]: {
+                            "exists": True,
+                            "deactivated": True,
+                        },
+                    }
+                }
+            if destination == "otherremote":
+                return {}
+            if destination == "badremote":
+                # badremote tries to overwrite the status of a user that doesn't belong
+                # to it (i.e. users[1]) with false data, which Synapse is expected to
+                # ignore.
+                return {
+                    "account_statuses": {
+                        users[3]: {
+                            "exists": False,
+                        },
+                        users[1]: {
+                            "exists": False,
+                        },
+                    }
+                }
+
+        # Register a mock that will return the expected result depending on the remote.
+        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+
+        # Check that we've got the correct response from the client-side endpoint.
+        self._test_status(
+            users=users,
+            expected_statuses={
+                users[0]: {
+                    "exists": False,
+                },
+                users[1]: {
+                    "exists": True,
+                    "deactivated": True,
+                },
+                users[3]: {
+                    "exists": False,
+                },
+            },
+            expected_failures=[users[2]],
+        )
+
+    def _test_status(
+        self,
+        users: Optional[List[str]],
+        expected_status_code: int = 200,
+        expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
+        expected_failures: Optional[List[str]] = None,
+        expected_errcode: Optional[str] = None,
+    ):
+        """Send a request to the account status endpoint and check that the response
+        matches with what's expected.
+
+        Args:
+            users: The account(s) to request the status of, if any. If set to None, no
+                `user_id` query parameter will be included in the request.
+            expected_status_code: The expected HTTP status code.
+            expected_statuses: The expected account statuses, if any.
+            expected_failures: The expected failures, if any.
+            expected_errcode: The expected Matrix error code, if any.
+        """
+        content = {}
+        if users is not None:
+            content["user_ids"] = users
+
+        channel = self.make_request(
+            method="POST",
+            path=self.url,
+            content=content,
+            access_token=self.requester_tok,
+        )
+
+        self.assertEqual(channel.code, expected_status_code)
+
+        if expected_statuses is not None:
+            self.assertEqual(channel.json_body["account_statuses"], expected_statuses)
+
+        if expected_failures is not None:
+            self.assertEqual(channel.json_body["failures"], expected_failures)
+
+        if expected_errcode is not None:
+            self.assertEqual(channel.json_body["errcode"], expected_errcode)