summary refs log tree commit diff
path: root/tests/rest/admin/test_user.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/rest/admin/test_user.py1329
1 files changed, 944 insertions, 385 deletions
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py

index 16bb4349f5..412718f06c 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -21,9 +21,12 @@ import hashlib import hmac +import json import os +import time import urllib.parse from binascii import unhexlify +from http import HTTPStatus from typing import Dict, List, Optional from unittest.mock import AsyncMock, Mock, patch @@ -33,7 +36,13 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin -from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes +from synapse.api.constants import ( + ApprovalNoticeMedium, + EventContentFields, + EventTypes, + LoginType, + UserTypes, +) from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.media.filepath import MediaFilePaths @@ -42,6 +51,7 @@ from synapse.rest.client import ( devices, login, logout, + media, profile, register, room, @@ -54,7 +64,9 @@ from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.test_utils import SMALL_PNG +from tests.test_utils.event_injection import inject_event from tests.unittest import override_config @@ -316,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) + @override_config( + { + "user_types": { + "extra_user_types": ["extra1", "extra2"], + } + } + ) + def test_extra_user_type(self) -> None: + """ + Check that the extra user type can be used when registering a user. + """ + + def nonce_mac(user_type: str) -> tuple[str, str]: + """ + Get a nonce and the expected HMAC for that nonce. + """ + channel = self.make_request("GET", self.url) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update( + nonce.encode("ascii") + + b"\x00alice\x00abc123\x00notadmin\x00" + + user_type.encode("ascii") + ) + want_mac_str = want_mac.hexdigest() + + return nonce, want_mac_str + + nonce, mac = nonce_mac("extra1") + # Valid user_type + body = { + "nonce": nonce, + "username": "alice", + "password": "abc123", + "user_type": "extra1", + "mac": mac, + } + channel = self.make_request("POST", self.url, body) + self.assertEqual(200, channel.code, msg=channel.json_body) + + nonce, mac = nonce_mac("extra3") + # Invalid user_type + body = { + "nonce": nonce, + "username": "alice", + "password": "abc123", + "user_type": "extra3", + "mac": mac, + } + channel = self.make_request("POST", self.url, body) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Invalid user type", channel.json_body["error"]) + def test_displayname(self) -> None: """ Test that displayname of new user is set @@ -715,7 +782,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - # unkown order_by + # unknown order_by channel = self.make_request( "GET", self.url + "?order_by=bar", @@ -1174,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase): not_user_types=["custom"], ) + @override_config( + { + "user_types": { + "extra_user_types": ["extra1", "extra2"], + } + } + ) + def test_filter_not_user_types_with_extra(self) -> None: + """Tests that the endpoint handles the not_user_types param when extra_user_types are configured""" + + regular_user_id = self.register_user("normalo", "secret") + + extra1_user_id = self.register_user("extra1", "secret") + self.make_request( + "PUT", + "/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id), + {"user_type": "extra1"}, + access_token=self.admin_user_tok, + ) + + def test_user_type( + expected_user_ids: List[str], not_user_types: Optional[List[str]] = None + ) -> None: + """Runs a test for the not_user_types param + Args: + expected_user_ids: Ids of the users that are expected to be returned + not_user_types: List of values for the not_user_types param + """ + + user_type_query = "" + + if not_user_types is not None: + user_type_query = "&".join( + [f"not_user_type={u}" for u in not_user_types] + ) + + test_url = f"{self.url}?{user_type_query}" + channel = self.make_request( + "GET", + test_url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code) + self.assertEqual(channel.json_body["total"], len(expected_user_ids)) + self.assertEqual( + expected_user_ids, + [u["name"] for u in channel.json_body["users"]], + ) + + # Request without user_types → all users expected + test_user_type([self.admin_user, extra1_user_id, regular_user_id]) + + # Request and exclude extra1 user type + test_user_type( + [self.admin_user, regular_user_id], + not_user_types=["extra1"], + ) + + # Request and exclude extra1 and extra2 user types + test_user_type( + [self.admin_user, regular_user_id], + not_user_types=["extra1", "extra2"], + ) + + # Request and exclude empty user types → only expected the extra1 user + test_user_type([extra1_user_id], not_user_types=[""]) + + # Request and exclude an unregistered type → expect all users + test_user_type( + [self.admin_user, extra1_user_id, regular_user_id], + not_user_types=["extra3"], + ) + def test_erasure_status(self) -> None: # Create a new user. user_id = self.register_user("eraseme", "eraseme") @@ -1411,9 +1552,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): UserID.from_string("@user:test"), "mxc://servername/mediaid" ) ) - self.get_success( - self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) - ) def test_no_auth(self) -> None: """ @@ -1500,7 +1638,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) - self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) self.assertFalse(channel.json_body["erased"]) @@ -1525,7 +1662,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) - self.assertEqual(0, len(channel.json_body["threepids"])) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) self.assertTrue(channel.json_body["erased"]) @@ -1568,7 +1704,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) - self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) @@ -1592,7 +1727,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) - self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User1", channel.json_body["displayname"]) @@ -1622,7 +1756,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) - self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) @@ -1646,7 +1779,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) - self.assertEqual(0, len(channel.json_body["threepids"])) self.assertIsNone(channel.json_body["avatar_url"]) self.assertIsNone(channel.json_body["displayname"]) @@ -1817,25 +1949,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - # threepids not valid - channel = self.make_request( - "PUT", - self.url_other_user, - access_token=self.admin_user_tok, - content={"threepids": {"medium": "email", "wrong_address": "id"}}, - ) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - - channel = self.make_request( - "PUT", - self.url_other_user, - access_token=self.admin_user_tok, - content={"threepids": {"address": "value"}}, - ) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - def test_get_user(self) -> None: """ Test a simple get of a user. @@ -1890,8 +2003,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1906,8 +2017,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertTrue(channel.json_body["admin"]) self.assertFalse(channel.json_body["is_guest"]) self.assertFalse(channel.json_body["deactivated"]) @@ -1945,9 +2054,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual( "external_id1", channel.json_body["external_ids"][0]["external_id"] ) @@ -1969,8 +2075,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertFalse(channel.json_body["admin"]) self.assertFalse(channel.json_body["is_guest"]) self.assertFalse(channel.json_body["deactivated"]) @@ -2062,123 +2166,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) - @override_config( - { - "email": { - "enable_notifs": True, - "notif_for_new_users": True, - "notif_from": "test@example.com", - }, - "public_baseurl": "https://example.com", - } - ) - def test_create_user_email_notif_for_new_users(self) -> None: - """ - Check that a new regular user is created successfully and - got an email pusher. - """ - url = self.url_prefix % "@bob:test" - - # Create user - body = { - "password": "abc123", - # Note that the given email is not in canonical form. - "threepids": [{"medium": "email", "address": "Bob@bob.bob"}], - } - - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - content=body, - ) - - self.assertEqual(201, channel.code, msg=channel.json_body) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - - pushers = list( - self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) - ) - self.assertEqual(len(pushers), 1) - self.assertEqual("@bob:test", pushers[0].user_name) - - @override_config( - { - "email": { - "enable_notifs": False, - "notif_for_new_users": False, - "notif_from": "test@example.com", - }, - "public_baseurl": "https://example.com", - } - ) - def test_create_user_email_no_notif_for_new_users(self) -> None: - """ - Check that a new regular user is created successfully and - got not an email pusher. - """ - url = self.url_prefix % "@bob:test" - - # Create user - body = { - "password": "abc123", - "threepids": [{"medium": "email", "address": "bob@bob.bob"}], - } - - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - content=body, - ) - - self.assertEqual(201, channel.code, msg=channel.json_body) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - - pushers = list( - self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) - ) - self.assertEqual(len(pushers), 0) - - @override_config( - { - "email": { - "enable_notifs": True, - "notif_for_new_users": True, - "notif_from": "test@example.com", - }, - "public_baseurl": "https://example.com", - } - ) - def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None: - """ - Check that a new regular user is created successfully when they have a msisdn - threepid and email notif_for_new_users is set to True. - """ - url = self.url_prefix % "@bob:test" - - # Create user - body = { - "password": "abc123", - "threepids": [{"medium": "msisdn", "address": "1234567890"}], - } - - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - content=body, - ) - - self.assertEqual(201, channel.code, msg=channel.json_body) - self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"]) - def test_set_password(self) -> None: """ Test setting a new password for another user. @@ -2222,89 +2209,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) - - def test_set_threepid(self) -> None: - """ - Test setting threepid for an other user. - """ - - # Add two threepids to user - channel = self.make_request( - "PUT", - self.url_other_user, - access_token=self.admin_user_tok, - content={ - "threepids": [ - {"medium": "email", "address": "bob1@bob.bob"}, - {"medium": "email", "address": "bob2@bob.bob"}, - ], - }, - ) - - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(2, len(channel.json_body["threepids"])) - # result does not always have the same sort order, therefore it becomes sorted - sorted_result = sorted( - channel.json_body["threepids"], key=lambda k: k["address"] - ) - self.assertEqual("email", sorted_result[0]["medium"]) - self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) - self.assertEqual("email", sorted_result[1]["medium"]) - self.assertEqual("bob2@bob.bob", sorted_result[1]["address"]) - self._check_fields(channel.json_body) - - # Set a new and remove a threepid - channel = self.make_request( - "PUT", - self.url_other_user, - access_token=self.admin_user_tok, - content={ - "threepids": [ - {"medium": "email", "address": "bob2@bob.bob"}, - {"medium": "email", "address": "bob3@bob.bob"}, - ], - }, - ) - - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(2, len(channel.json_body["threepids"])) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) - self._check_fields(channel.json_body) - - # Get user - channel = self.make_request( - "GET", - self.url_other_user, - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(2, len(channel.json_body["threepids"])) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) - self._check_fields(channel.json_body) - - # Remove threepids - channel = self.make_request( - "PUT", - self.url_other_user, - access_token=self.admin_user_tok, - content={"threepids": []}, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(0, len(channel.json_body["threepids"])) - self._check_fields(channel.json_body) - - def test_set_duplicate_threepid(self) -> None: """ Test setting the same threepid for a second user. First user loses and second user gets mapping of this threepid. @@ -2328,9 +2232,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) - self.assertEqual(1, len(channel.json_body["threepids"])) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob1@bob.bob", channel.json_body["threepids"][0]["address"]) self._check_fields(channel.json_body) # Add threepids to other user @@ -2347,9 +2248,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(1, len(channel.json_body["threepids"])) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) self._check_fields(channel.json_body) # Add two new threepids to other user @@ -2369,15 +2267,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): # other user has this two threepids self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(2, len(channel.json_body["threepids"])) - # result does not always have the same sort order, therefore it becomes sorted - sorted_result = sorted( - channel.json_body["threepids"], key=lambda k: k["address"] - ) - self.assertEqual("email", sorted_result[0]["medium"]) - self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) - self.assertEqual("email", sorted_result[1]["medium"]) - self.assertEqual("bob3@bob.bob", sorted_result[1]["address"]) self._check_fields(channel.json_body) # first_user has no threepid anymore @@ -2388,7 +2277,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) - self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) def test_set_external_id(self) -> None: @@ -2623,9 +2511,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): UserID.from_string("@user:test"), "mxc://servername/mediaid" ) ) - self.get_success( - self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0) - ) # Get user channel = self.make_request( @@ -2637,7 +2522,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) - self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -2652,7 +2536,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) - self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -2671,7 +2554,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) - self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -2965,22 +2847,18 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) - def test_set_user_type(self) -> None: - """ - Test changing user type. - """ - - # Set to support type + def set_user_type(self, user_type: Optional[str]) -> None: + # Set to user_type channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content={"user_type": UserTypes.SUPPORT}, + content={"user_type": user_type}, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + self.assertEqual(user_type, channel.json_body["user_type"]) # Get user channel = self.make_request( @@ -2991,30 +2869,44 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + self.assertEqual(user_type, channel.json_body["user_type"]) + + def test_set_user_type(self) -> None: + """ + Test changing user type. + """ + + # Set to support type + self.set_user_type(UserTypes.SUPPORT) # Change back to a regular user - channel = self.make_request( - "PUT", - self.url_other_user, - access_token=self.admin_user_tok, - content={"user_type": None}, - ) + self.set_user_type(None) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertIsNone(channel.json_body["user_type"]) + @override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}}) + def test_set_user_type_with_extras(self) -> None: + """ + Test changing user type with extra_user_types configured. + """ - # Get user + # Check that we can still set to support type + self.set_user_type(UserTypes.SUPPORT) + + # Check that we can set to an extra user type + self.set_user_type("extra2") + + # Change back to a regular user + self.set_user_type(None) + + # Try setting to invalid type channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"user_type": "extra3"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertIsNone(channel.json_body["user_type"]) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Invalid user type", channel.json_body["error"]) def test_accidental_deactivation_prevention(self) -> None: """ @@ -3204,7 +3096,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): content: Content dictionary to check """ self.assertIn("displayname", content) - self.assertIn("threepids", content) self.assertIn("avatar_url", content) self.assertIn("admin", content) self.assertIn("deactivated", content) @@ -3217,6 +3108,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("consent_ts", content) self.assertIn("external_ids", content) self.assertIn("last_seen_ts", content) + self.assertIn("suspended", content) # This key was removed intentionally. Ensure it is not accidentally re-included. self.assertNotIn("password_hash", content) @@ -3513,6 +3405,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, + media.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -3692,7 +3585,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): @parameterized.expand(["GET", "DELETE"]) def test_invalid_parameter(self, method: str) -> None: """If parameters are invalid, an error is returned.""" - # unkown order_by + # unknown order_by channel = self.make_request( method, self.url + "?order_by=bar", @@ -3887,9 +3780,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): image_data1 = SMALL_PNG # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B image_data2 = unhexlify( - b"47494638376101000100800100000000" - b"ffffff2c00000000010001000002024c" - b"01003b" + b"47494638376101000100800100000000ffffff2c00000000010001000002024c01003b" ) # Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B image_data3 = unhexlify( @@ -4019,7 +3910,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Try to access a media and to create `last_access_ts` channel = self.make_request( "GET", - f"/_matrix/media/v3/download/{server_and_media_id}", + f"/_matrix/client/v1/media/download/{server_and_media_id}", shorthand=False, access_token=user_token, ) @@ -4858,100 +4749,6 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase): ) -class UsersByThreePidTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.get_success( - self.store.user_add_threepid( - self.other_user, "email", "user@email.com", 1, 1 - ) - ) - self.get_success( - self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1) - ) - - def test_no_auth(self) -> None: - """Try to look up a user without authentication.""" - url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" - - channel = self.make_request( - "GET", - url, - ) - - self.assertEqual(401, channel.code, msg=channel.json_body) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - def test_medium_does_not_exist(self) -> None: - """Tests that both a lookup for a medium that does not exist and a user that - doesn't exist with that third party ID returns a 404""" - # test for unknown medium - url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key" - - channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - # test for unknown user with a known medium - url = "/_synapse/admin/v1/threepid/email/users/unknown" - - channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_success(self) -> None: - """Tests a successful medium + address lookup""" - # test for email medium with encoded value of user@email.com - url = "/_synapse/admin/v1/threepid/email/users/user%40email.com" - - channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual( - {"user_id": self.other_user}, - channel.json_body, - ) - - # test for msidn medium with encoded value of +1-12345678 - url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678" - - channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual( - {"user_id": self.other_user}, - channel.json_body, - ) - - class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -5024,7 +4821,6 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main - @override_config({"experimental_features": {"msc3823_account_suspension": True}}) def test_suspend_user(self) -> None: # test that suspending user works channel = self.make_request( @@ -5089,3 +4885,766 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase): res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user)) self.assertEqual(True, res5) + + +class UserRedactionTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin = self.register_user("thomas", "pass", True) + self.admin_tok = self.login("thomas", "pass") + + self.bad_user = self.register_user("teresa", "pass") + self.bad_user_tok = self.login("teresa", "pass") + + self.store = hs.get_datastores().main + + self.spam_checker = hs.get_module_api_callbacks().spam_checker + + # create rooms - room versions 11+ store the `redacts` key in content while + # earlier ones don't so we use a mix of room versions + self.rm1 = self.helper.create_room_as( + self.admin, tok=self.admin_tok, room_version="7" + ) + self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok) + self.rm3 = self.helper.create_room_as( + self.admin, tok=self.admin_tok, room_version="11" + ) + + def test_redact_messages_all_rooms(self) -> None: + """ + Test that request to redact events in all rooms user is member of is successful + """ + # join rooms, send some messages + originals = [] + for rm in [self.rm1, self.rm2, self.rm3]: + join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok) + originals.append(join["event_id"]) + for i in range(15): + event = {"body": f"hello{i}", "msgtype": "m.text"} + res = self.helper.send_event( + rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200 + ) + originals.append(res["event_id"]) + + # redact all events in all rooms + channel = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": []}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + matched = [] + for rm in [self.rm1, self.rm2, self.rm3]: + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel = self.make_request( + "GET", + f"rooms/{rm}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + for event in channel.json_body["chunk"]: + for event_id in originals: + if ( + event["type"] == "m.room.redaction" + and event["redacts"] == event_id + ): + matched.append(event_id) + self.assertEqual(len(matched), len(originals)) + + def test_redact_messages_specific_rooms(self) -> None: + """ + Test that request to redact events in specified rooms user is member of is successful + """ + + originals = [] + for rm in [self.rm1, self.rm2, self.rm3]: + join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok) + originals.append(join["event_id"]) + for i in range(15): + event = {"body": f"hello{i}", "msgtype": "m.text"} + res = self.helper.send_event( + rm, "m.room.message", event, tok=self.bad_user_tok + ) + originals.append(res["event_id"]) + + # redact messages in rooms 1 and 3 + channel = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": [self.rm1, self.rm3]}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + # messages in requested rooms are redacted + for rm in [self.rm1, self.rm3]: + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel = self.make_request( + "GET", + f"rooms/{rm}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + matches = [] + for event in channel.json_body["chunk"]: + for event_id in originals: + if ( + event["type"] == "m.room.redaction" + and event["redacts"] == event_id + ): + matches.append((event_id, event)) + # we redacted 16 messages + self.assertEqual(len(matches), 16) + + channel = self.make_request( + "GET", f"rooms/{self.rm2}/messages?limit=50", access_token=self.admin_tok + ) + self.assertEqual(channel.code, 200) + + # messages in remaining room are not + for event in channel.json_body["chunk"]: + if event["type"] == "m.room.redaction": + self.fail("found redaction in room 2") + + def test_redact_status(self) -> None: + rm2_originals = [] + for rm in [self.rm1, self.rm2, self.rm3]: + join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok) + if rm == self.rm2: + rm2_originals.append(join["event_id"]) + for i in range(5): + event = {"body": f"hello{i}", "msgtype": "m.text"} + res = self.helper.send_event( + rm, "m.room.message", event, tok=self.bad_user_tok + ) + if rm == self.rm2: + rm2_originals.append(res["event_id"]) + + # redact messages in rooms 1 and 3 + channel = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": [self.rm1, self.rm3]}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + id = channel.json_body.get("redact_id") + + channel2 = self.make_request( + "GET", + f"/_synapse/admin/v1/user/redact_status/{id}", + access_token=self.admin_tok, + ) + self.assertEqual(channel2.code, 200) + self.assertEqual(channel2.json_body.get("status"), "complete") + self.assertEqual(channel2.json_body.get("failed_redactions"), {}) + + # mock that will cause persisting the redaction events to fail + async def check_event_for_spam(event: str) -> str: + return "spam" + + self.spam_checker.check_event_for_spam = check_event_for_spam # type: ignore + + channel3 = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": [self.rm2]}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + id = channel3.json_body.get("redact_id") + + channel4 = self.make_request( + "GET", + f"/_synapse/admin/v1/user/redact_status/{id}", + access_token=self.admin_tok, + ) + self.assertEqual(channel4.code, 200) + self.assertEqual(channel4.json_body.get("status"), "complete") + failed_redactions = channel4.json_body.get("failed_redactions") + assert failed_redactions is not None + matched = [] + for original in rm2_originals: + if failed_redactions.get(original) is not None: + matched.append(original) + self.assertEqual(len(matched), len(rm2_originals)) + + def test_admin_redact_works_if_user_kicked_or_banned(self) -> None: + originals1 = [] + originals2 = [] + for rm in [self.rm1, self.rm2, self.rm3]: + join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok) + if rm in [self.rm1, self.rm3]: + originals1.append(join["event_id"]) + else: + originals2.append(join["event_id"]) + for i in range(5): + event = {"body": f"hello{i}", "msgtype": "m.text"} + res = self.helper.send_event( + rm, "m.room.message", event, tok=self.bad_user_tok + ) + if rm in [self.rm1, self.rm3]: + originals1.append(res["event_id"]) + else: + originals2.append(res["event_id"]) + + # kick user from rooms 1 and 3 + for r in [self.rm1, self.rm3]: + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{r}/kick", + content={"reason": "being a bummer", "user_id": self.bad_user}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # redact messages in room 1 and 3 + channel1 = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": [self.rm1, self.rm3]}, + access_token=self.admin_tok, + ) + self.assertEqual(channel1.code, 200) + id = channel1.json_body.get("redact_id") + + # check that there were no failed redactions in room 1 and 3 + channel2 = self.make_request( + "GET", + f"/_synapse/admin/v1/user/redact_status/{id}", + access_token=self.admin_tok, + ) + self.assertEqual(channel2.code, 200) + self.assertEqual(channel2.json_body.get("status"), "complete") + failed_redactions = channel2.json_body.get("failed_redactions") + self.assertEqual(failed_redactions, {}) + + # double check + for rm in [self.rm1, self.rm3]: + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel3 = self.make_request( + "GET", + f"rooms/{rm}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel3.code, 200) + + matches = [] + for event in channel3.json_body["chunk"]: + for event_id in originals1: + if ( + event["type"] == "m.room.redaction" + and event["redacts"] == event_id + ): + matches.append((event_id, event)) + # we redacted 6 messages + self.assertEqual(len(matches), 6) + + # ban user from room 2 + channel4 = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.rm2}/ban", + content={"reason": "being a bummer", "user_id": self.bad_user}, + access_token=self.admin_tok, + ) + self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result) + + # make a request to ban all user's messages + channel5 = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": []}, + access_token=self.admin_tok, + ) + self.assertEqual(channel5.code, 200) + id2 = channel5.json_body.get("redact_id") + + # check that there were no failed redactions in room 2 + channel6 = self.make_request( + "GET", + f"/_synapse/admin/v1/user/redact_status/{id2}", + access_token=self.admin_tok, + ) + self.assertEqual(channel6.code, 200) + self.assertEqual(channel6.json_body.get("status"), "complete") + failed_redactions = channel6.json_body.get("failed_redactions") + self.assertEqual(failed_redactions, {}) + + # double check messages in room 2 were redacted + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel7 = self.make_request( + "GET", + f"rooms/{self.rm2}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel7.code, 200) + + matches = [] + for event in channel7.json_body["chunk"]: + for event_id in originals2: + if event["type"] == "m.room.redaction" and event["redacts"] == event_id: + matches.append((event_id, event)) + # we redacted 6 messages + self.assertEqual(len(matches), 6) + + def test_redactions_for_remote_user_succeed_with_admin_priv_in_room(self) -> None: + """ + Test that if the admin requester has privileges in a room, redaction requests + succeed for a remote user + """ + + # inject some messages from remote user and collect event ids + original_message_ids = [] + for i in range(5): + event = self.get_success( + inject_event( + self.hs, + room_id=self.rm1, + type="m.room.message", + sender="@remote:remote_server", + content={"msgtype": "m.text", "body": f"nefarious_chatter{i}"}, + ) + ) + original_message_ids.append(event.event_id) + + # send a request to redact a remote user's messages in a room. + # the server admin created this room and has admin privilege in room + channel = self.make_request( + "POST", + "/_synapse/admin/v1/user/@remote:remote_server/redact", + content={"rooms": [self.rm1]}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + id = channel.json_body.get("redact_id") + + # check that there were no failed redactions + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/user/redact_status/{id}", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body.get("status"), "complete") + failed_redactions = channel.json_body.get("failed_redactions") + self.assertEqual(failed_redactions, {}) + + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel = self.make_request( + "GET", + f"rooms/{self.rm1}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + for event in channel.json_body["chunk"]: + for event_id in original_message_ids: + if event["type"] == "m.room.redaction" and event["redacts"] == event_id: + original_message_ids.remove(event_id) + break + # we originally sent 5 messages so 5 should be redacted + self.assertEqual(len(original_message_ids), 0) + + def test_redact_redacts_encrypted_messages(self) -> None: + """ + Test that user's encrypted messages are redacted + """ + encrypted_room = self.helper.create_room_as( + self.admin, tok=self.admin_tok, room_version="7" + ) + self.helper.send_state( + encrypted_room, + EventTypes.RoomEncryption, + {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, + tok=self.admin_tok, + ) + # join room send some messages + originals = [] + join = self.helper.join(encrypted_room, self.bad_user, tok=self.bad_user_tok) + originals.append(join["event_id"]) + for _ in range(15): + res = self.helper.send_event( + encrypted_room, "m.room.encrypted", {}, tok=self.bad_user_tok + ) + originals.append(res["event_id"]) + + # redact user's events + channel = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": []}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + matched = [] + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel = self.make_request( + "GET", + f"rooms/{encrypted_room}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + for event in channel.json_body["chunk"]: + for event_id in originals: + if event["type"] == "m.room.redaction" and event["redacts"] == event_id: + matched.append(event_id) + self.assertEqual(len(matched), len(originals)) + + +class UserRedactionBackgroundTaskTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin = self.register_user("thomas", "pass", True) + self.admin_tok = self.login("thomas", "pass") + + self.bad_user = self.register_user("teresa", "pass") + self.bad_user_tok = self.login("teresa", "pass") + + # create rooms - room versions 11+ store the `redacts` key in content while + # earlier ones don't so we use a mix of room versions + self.rm1 = self.helper.create_room_as( + self.admin, tok=self.admin_tok, room_version="7" + ) + self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok) + self.rm3 = self.helper.create_room_as( + self.admin, tok=self.admin_tok, room_version="11" + ) + + @override_config({"run_background_tasks_on": "worker1"}) + def test_redact_messages_all_rooms(self) -> None: + """ + Test that redact task successfully runs when `run_background_tasks_on` is specified + """ + self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + + # join rooms, send some messages + original_event_ids = set() + for rm in [self.rm1, self.rm2, self.rm3]: + join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok) + original_event_ids.add(join["event_id"]) + for i in range(15): + event = {"body": f"hello{i}", "msgtype": "m.text"} + res = self.helper.send_event( + rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200 + ) + original_event_ids.add(res["event_id"]) + + # redact all events in all rooms + channel = self.make_request( + "POST", + f"/_synapse/admin/v1/user/{self.bad_user}/redact", + content={"rooms": []}, + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + id = channel.json_body.get("redact_id") + + timeout_s = 10 + start_time = time.time() + redact_result = "" + while redact_result != "complete": + if start_time + timeout_s < time.time(): + self.fail("Timed out waiting for redactions.") + + channel2 = self.make_request( + "GET", + f"/_synapse/admin/v1/user/redact_status/{id}", + access_token=self.admin_tok, + ) + redact_result = channel2.json_body["status"] + if redact_result == "failed": + self.fail("Redaction task failed.") + + redaction_ids = set() + for rm in [self.rm1, self.rm2, self.rm3]: + filter = json.dumps({"types": [EventTypes.Redaction]}) + channel = self.make_request( + "GET", + f"rooms/{rm}/messages?filter={filter}&limit=50", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + + for event in channel.json_body["chunk"]: + if event["type"] == "m.room.redaction": + redaction_ids.add(event["redacts"]) + + self.assertIncludes(redaction_ids, original_event_ids, exact=True) + + +class GetInvitesFromUserTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin = self.register_user("thomas", "pass", True) + self.admin_tok = self.login("thomas", "pass") + + self.bad_user = self.register_user("teresa", "pass") + self.bad_user_tok = self.login("teresa", "pass") + + self.random_users = [] + for i in range(4): + self.random_users.append(self.register_user(f"user{i}", f"pass{i}")) + + self.room1 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok) + self.room2 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok) + self.room3 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok) + + @unittest.override_config( + {"rc_invites": {"per_issuer": {"per_second": 1000, "burst_count": 1000}}} + ) + def test_get_user_invite_count_new_invites_test_case(self) -> None: + """ + Test that new invites that arrive after a provided timestamp are counted + """ + # grab a current timestamp + before_invites_sent_ts = self.hs.get_clock().time_msec() + + # bad user sends some invites + for room_id in [self.room1, self.room2]: + for user in self.random_users: + self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok) + + # fetch using timestamp, all should be returned + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["invite_count"], 8) + + # send some more invites, they should show up in addition to original 8 using same timestamp + for user in self.random_users: + self.helper.invite( + self.room3, src=self.bad_user, targ=user, tok=self.bad_user_tok + ) + + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["invite_count"], 12) + + def test_get_user_invite_count_invites_before_ts_test_case(self) -> None: + """ + Test that invites sent before provided ts are not counted + """ + # bad user sends some invites + for room_id in [self.room1, self.room2]: + for user in self.random_users: + self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok) + + # add a msec between last invite and ts + after_invites_sent_ts = self.hs.get_clock().time_msec() + 1 + + # fetch invites with timestamp, none should be returned + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={after_invites_sent_ts}", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["invite_count"], 0) + + def test_user_invite_count_kick_ban_not_counted(self) -> None: + """ + Test that kicks and bans are not counted in invite count + """ + to_kick_user_id = self.register_user("kick_me", "pass") + to_kick_tok = self.login("kick_me", "pass") + + self.helper.join(self.room1, to_kick_user_id, tok=to_kick_tok) + + # grab a current timestamp + before_invites_sent_ts = self.hs.get_clock().time_msec() + + # bad user sends some invites (8) + for room_id in [self.room1, self.room2]: + for user in self.random_users: + self.helper.invite( + room_id, src=self.bad_user, targ=user, tok=self.bad_user_tok + ) + + # fetch using timestamp, all invites sent should be counted + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["invite_count"], 8) + + # send a kick and some bans and make sure these aren't counted against invite total + for user in self.random_users: + self.helper.ban( + self.room1, src=self.bad_user, targ=user, tok=self.bad_user_tok + ) + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{self.room1}/kick", + content={"user_id": to_kick_user_id}, + access_token=self.bad_user_tok, + ) + self.assertEqual(channel.code, 200) + + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}", + access_token=self.admin_tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["invite_count"], 8) + + +class GetCumulativeJoinedRoomCountForUserTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + admin.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin = self.register_user("thomas", "pass", True) + self.admin_tok = self.login("thomas", "pass") + + self.bad_user = self.register_user("teresa", "pass") + self.bad_user_tok = self.login("teresa", "pass") + + def test_user_cumulative_joined_room_count(self) -> None: + """ + Tests proper count returned from /cumulative_joined_room_count endpoint + """ + # Create rooms and join, grab timestamp before room creation + before_room_creation_timestamp = self.hs.get_clock().time_msec() + + joined_rooms = [] + for _ in range(3): + room = self.helper.create_room_as(self.admin, tok=self.admin_tok) + self.helper.join( + room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok + ) + joined_rooms.append(room) + + # get a timestamp after room creation and join, add a msec between last join and ts + after_room_creation = self.hs.get_clock().time_msec() + 1 + + # Get rooms using this timestamp, there should be none since all rooms were created and joined + # before provided timestamp + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(after_room_creation)}", + access_token=self.admin_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["cumulative_joined_room_count"]) + + # fetch rooms with the older timestamp before they were created and joined, this should + # return the rooms + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}", + access_token=self.admin_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + len(joined_rooms), channel.json_body["cumulative_joined_room_count"] + ) + + def test_user_joined_room_count_includes_left_and_banned_rooms(self) -> None: + """ + Tests proper count returned from /joined_room_count endpoint when user has left + or been banned from joined rooms + """ + # Create rooms and join, grab timestamp before room creation + before_room_creation_timestamp = self.hs.get_clock().time_msec() + + joined_rooms = [] + for _ in range(3): + room = self.helper.create_room_as(self.admin, tok=self.admin_tok) + self.helper.join( + room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok + ) + joined_rooms.append(room) + + # fetch rooms with the older timestamp before they were created and joined + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}", + access_token=self.admin_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + len(joined_rooms), channel.json_body["cumulative_joined_room_count"] + ) + + # have the user banned from/leave the joined rooms + self.helper.ban( + joined_rooms[0], + src=self.admin, + targ=self.bad_user, + expect_code=200, + tok=self.admin_tok, + ) + self.helper.change_membership( + joined_rooms[1], + src=self.bad_user, + targ=self.bad_user, + membership="leave", + expect_code=200, + tok=self.bad_user_tok, + ) + self.helper.ban( + joined_rooms[2], + src=self.admin, + targ=self.bad_user, + expect_code=200, + tok=self.admin_tok, + ) + + # fetch the joined room count again, the number should remain the same as the collected joined rooms + channel = self.make_request( + "GET", + f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}", + access_token=self.admin_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual( + len(joined_rooms), channel.json_body["cumulative_joined_room_count"] + )