diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 16bb4349f5..412718f06c 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -21,9 +21,12 @@
import hashlib
import hmac
+import json
import os
+import time
import urllib.parse
from binascii import unhexlify
+from http import HTTPStatus
from typing import Dict, List, Optional
from unittest.mock import AsyncMock, Mock, patch
@@ -33,7 +36,13 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
-from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
+from synapse.api.constants import (
+ ApprovalNoticeMedium,
+ EventContentFields,
+ EventTypes,
+ LoginType,
+ UserTypes,
+)
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.media.filepath import MediaFilePaths
@@ -42,6 +51,7 @@ from synapse.rest.client import (
devices,
login,
logout,
+ media,
profile,
register,
room,
@@ -54,7 +64,9 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import SMALL_PNG
+from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config
@@ -316,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_extra_user_type(self) -> None:
+ """
+ Check that the extra user type can be used when registering a user.
+ """
+
+ def nonce_mac(user_type: str) -> tuple[str, str]:
+ """
+ Get a nonce and the expected HMAC for that nonce.
+ """
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii")
+ + b"\x00alice\x00abc123\x00notadmin\x00"
+ + user_type.encode("ascii")
+ )
+ want_mac_str = want_mac.hexdigest()
+
+ return nonce, want_mac_str
+
+ nonce, mac = nonce_mac("extra1")
+ # Valid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra1",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ nonce, mac = nonce_mac("extra3")
+ # Invalid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra3",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
+
def test_displayname(self) -> None:
"""
Test that displayname of new user is set
@@ -715,7 +782,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
@@ -1174,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase):
not_user_types=["custom"],
)
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_filter_not_user_types_with_extra(self) -> None:
+ """Tests that the endpoint handles the not_user_types param when extra_user_types are configured"""
+
+ regular_user_id = self.register_user("normalo", "secret")
+
+ extra1_user_id = self.register_user("extra1", "secret")
+ self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id),
+ {"user_type": "extra1"},
+ access_token=self.admin_user_tok,
+ )
+
+ def test_user_type(
+ expected_user_ids: List[str], not_user_types: Optional[List[str]] = None
+ ) -> None:
+ """Runs a test for the not_user_types param
+ Args:
+ expected_user_ids: Ids of the users that are expected to be returned
+ not_user_types: List of values for the not_user_types param
+ """
+
+ user_type_query = ""
+
+ if not_user_types is not None:
+ user_type_query = "&".join(
+ [f"not_user_type={u}" for u in not_user_types]
+ )
+
+ test_url = f"{self.url}?{user_type_query}"
+ channel = self.make_request(
+ "GET",
+ test_url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code)
+ self.assertEqual(channel.json_body["total"], len(expected_user_ids))
+ self.assertEqual(
+ expected_user_ids,
+ [u["name"] for u in channel.json_body["users"]],
+ )
+
+ # Request without user_types → all users expected
+ test_user_type([self.admin_user, extra1_user_id, regular_user_id])
+
+ # Request and exclude extra1 user type
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1"],
+ )
+
+ # Request and exclude extra1 and extra2 user types
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1", "extra2"],
+ )
+
+ # Request and exclude empty user types → only expected the extra1 user
+ test_user_type([extra1_user_id], not_user_types=[""])
+
+ # Request and exclude an unregistered type → expect all users
+ test_user_type(
+ [self.admin_user, extra1_user_id, regular_user_id],
+ not_user_types=["extra3"],
+ )
+
def test_erasure_status(self) -> None:
# Create a new user.
user_id = self.register_user("eraseme", "eraseme")
@@ -1411,9 +1552,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
)
- self.get_success(
- self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
- )
def test_no_auth(self) -> None:
"""
@@ -1500,7 +1638,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
self.assertFalse(channel.json_body["erased"])
@@ -1525,7 +1662,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
self.assertTrue(channel.json_body["erased"])
@@ -1568,7 +1704,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
@@ -1592,7 +1727,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
@@ -1622,7 +1756,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
@@ -1646,7 +1779,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
@@ -1817,25 +1949,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
- # threepids not valid
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": {"medium": "email", "wrong_address": "id"}},
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": {"address": "value"}},
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
def test_get_user(self) -> None:
"""
Test a simple get of a user.
@@ -1890,8 +2003,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1906,8 +2017,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertTrue(channel.json_body["admin"])
self.assertFalse(channel.json_body["is_guest"])
self.assertFalse(channel.json_body["deactivated"])
@@ -1945,9 +2054,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
@@ -1969,8 +2075,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertFalse(channel.json_body["admin"])
self.assertFalse(channel.json_body["is_guest"])
self.assertFalse(channel.json_body["deactivated"])
@@ -2062,123 +2166,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
- @override_config(
- {
- "email": {
- "enable_notifs": True,
- "notif_for_new_users": True,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_notif_for_new_users(self) -> None:
- """
- Check that a new regular user is created successfully and
- got an email pusher.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- # Note that the given email is not in canonical form.
- "threepids": [{"medium": "email", "address": "Bob@bob.bob"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-
- pushers = list(
- self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
- )
- self.assertEqual(len(pushers), 1)
- self.assertEqual("@bob:test", pushers[0].user_name)
-
- @override_config(
- {
- "email": {
- "enable_notifs": False,
- "notif_for_new_users": False,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_no_notif_for_new_users(self) -> None:
- """
- Check that a new regular user is created successfully and
- got not an email pusher.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-
- pushers = list(
- self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
- )
- self.assertEqual(len(pushers), 0)
-
- @override_config(
- {
- "email": {
- "enable_notifs": True,
- "notif_for_new_users": True,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None:
- """
- Check that a new regular user is created successfully when they have a msisdn
- threepid and email notif_for_new_users is set to True.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- "threepids": [{"medium": "msisdn", "address": "1234567890"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
-
def test_set_password(self) -> None:
"""
Test setting a new password for another user.
@@ -2222,89 +2209,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
-
- def test_set_threepid(self) -> None:
- """
- Test setting threepid for an other user.
- """
-
- # Add two threepids to user
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={
- "threepids": [
- {"medium": "email", "address": "bob1@bob.bob"},
- {"medium": "email", "address": "bob2@bob.bob"},
- ],
- },
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- # result does not always have the same sort order, therefore it becomes sorted
- sorted_result = sorted(
- channel.json_body["threepids"], key=lambda k: k["address"]
- )
- self.assertEqual("email", sorted_result[0]["medium"])
- self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
- self.assertEqual("email", sorted_result[1]["medium"])
- self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
- self._check_fields(channel.json_body)
-
- # Set a new and remove a threepid
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={
- "threepids": [
- {"medium": "email", "address": "bob2@bob.bob"},
- {"medium": "email", "address": "bob3@bob.bob"},
- ],
- },
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
- self._check_fields(channel.json_body)
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_other_user,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
- self._check_fields(channel.json_body)
-
- # Remove threepids
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": []},
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
- self._check_fields(channel.json_body)
-
- def test_set_duplicate_threepid(self) -> None:
"""
Test setting the same threepid for a second user.
First user loses and second user gets mapping of this threepid.
@@ -2328,9 +2232,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob1@bob.bob", channel.json_body["threepids"][0]["address"])
self._check_fields(channel.json_body)
# Add threepids to other user
@@ -2347,9 +2248,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
self._check_fields(channel.json_body)
# Add two new threepids to other user
@@ -2369,15 +2267,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# other user has this two threepids
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- # result does not always have the same sort order, therefore it becomes sorted
- sorted_result = sorted(
- channel.json_body["threepids"], key=lambda k: k["address"]
- )
- self.assertEqual("email", sorted_result[0]["medium"])
- self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
- self.assertEqual("email", sorted_result[1]["medium"])
- self.assertEqual("bob3@bob.bob", sorted_result[1]["address"])
self._check_fields(channel.json_body)
# first_user has no threepid anymore
@@ -2388,7 +2277,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
def test_set_external_id(self) -> None:
@@ -2623,9 +2511,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
)
- self.get_success(
- self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
- )
# Get user
channel = self.make_request(
@@ -2637,7 +2522,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2652,7 +2536,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2671,7 +2554,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2965,22 +2847,18 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
- def test_set_user_type(self) -> None:
- """
- Test changing user type.
- """
-
- # Set to support type
+ def set_user_type(self, user_type: Optional[str]) -> None:
+ # Set to user_type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"user_type": UserTypes.SUPPORT},
+ content={"user_type": user_type},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
# Get user
channel = self.make_request(
@@ -2991,30 +2869,44 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
+
+ def test_set_user_type(self) -> None:
+ """
+ Test changing user type.
+ """
+
+ # Set to support type
+ self.set_user_type(UserTypes.SUPPORT)
# Change back to a regular user
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"user_type": None},
- )
+ self.set_user_type(None)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ @override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}})
+ def test_set_user_type_with_extras(self) -> None:
+ """
+ Test changing user type with extra_user_types configured.
+ """
- # Get user
+ # Check that we can still set to support type
+ self.set_user_type(UserTypes.SUPPORT)
+
+ # Check that we can set to an extra user type
+ self.set_user_type("extra2")
+
+ # Change back to a regular user
+ self.set_user_type(None)
+
+ # Try setting to invalid type
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"user_type": "extra3"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
def test_accidental_deactivation_prevention(self) -> None:
"""
@@ -3204,7 +3096,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content: Content dictionary to check
"""
self.assertIn("displayname", content)
- self.assertIn("threepids", content)
self.assertIn("avatar_url", content)
self.assertIn("admin", content)
self.assertIn("deactivated", content)
@@ -3217,6 +3108,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("consent_ts", content)
self.assertIn("external_ids", content)
self.assertIn("last_seen_ts", content)
+ self.assertIn("suspended", content)
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", content)
@@ -3513,6 +3405,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ media.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -3692,7 +3585,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
@parameterized.expand(["GET", "DELETE"])
def test_invalid_parameter(self, method: str) -> None:
"""If parameters are invalid, an error is returned."""
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
method,
self.url + "?order_by=bar",
@@ -3887,9 +3780,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
image_data1 = SMALL_PNG
# Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
image_data2 = unhexlify(
- b"47494638376101000100800100000000"
- b"ffffff2c00000000010001000002024c"
- b"01003b"
+ b"47494638376101000100800100000000ffffff2c00000000010001000002024c01003b"
)
# Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B
image_data3 = unhexlify(
@@ -4019,7 +3910,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts`
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_and_media_id}",
+ f"/_matrix/client/v1/media/download/{server_and_media_id}",
shorthand=False,
access_token=user_token,
)
@@ -4858,100 +4749,6 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase):
)
-class UsersByThreePidTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.get_success(
- self.store.user_add_threepid(
- self.other_user, "email", "user@email.com", 1, 1
- )
- )
- self.get_success(
- self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1)
- )
-
- def test_no_auth(self) -> None:
- """Try to look up a user without authentication."""
- url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
-
- channel = self.make_request(
- "GET",
- url,
- )
-
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- def test_medium_does_not_exist(self) -> None:
- """Tests that both a lookup for a medium that does not exist and a user that
- doesn't exist with that third party ID returns a 404"""
- # test for unknown medium
- url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- # test for unknown user with a known medium
- url = "/_synapse/admin/v1/threepid/email/users/unknown"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_success(self) -> None:
- """Tests a successful medium + address lookup"""
- # test for email medium with encoded value of user@email.com
- url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(
- {"user_id": self.other_user},
- channel.json_body,
- )
-
- # test for msidn medium with encoded value of +1-12345678
- url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(
- {"user_id": self.other_user},
- channel.json_body,
- )
-
-
class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -5024,7 +4821,6 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
- @override_config({"experimental_features": {"msc3823_account_suspension": True}})
def test_suspend_user(self) -> None:
# test that suspending user works
channel = self.make_request(
@@ -5089,3 +4885,766 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase):
res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
self.assertEqual(True, res5)
+
+
+class UserRedactionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ self.store = hs.get_datastores().main
+
+ self.spam_checker = hs.get_module_api_callbacks().spam_checker
+
+ # create rooms - room versions 11+ store the `redacts` key in content while
+ # earlier ones don't so we use a mix of room versions
+ self.rm1 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.rm3 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="11"
+ )
+
+ def test_redact_messages_all_rooms(self) -> None:
+ """
+ Test that request to redact events in all rooms user is member of is successful
+ """
+ # join rooms, send some messages
+ originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200
+ )
+ originals.append(res["event_id"])
+
+ # redact all events in all rooms
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matched = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matched.append(event_id)
+ self.assertEqual(len(matched), len(originals))
+
+ def test_redact_messages_specific_rooms(self) -> None:
+ """
+ Test that request to redact events in specified rooms user is member of is successful
+ """
+
+ originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ originals.append(res["event_id"])
+
+ # redact messages in rooms 1 and 3
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # messages in requested rooms are redacted
+ for rm in [self.rm1, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matches = []
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matches.append((event_id, event))
+ # we redacted 16 messages
+ self.assertEqual(len(matches), 16)
+
+ channel = self.make_request(
+ "GET", f"rooms/{self.rm2}/messages?limit=50", access_token=self.admin_tok
+ )
+ self.assertEqual(channel.code, 200)
+
+ # messages in remaining room are not
+ for event in channel.json_body["chunk"]:
+ if event["type"] == "m.room.redaction":
+ self.fail("found redaction in room 2")
+
+ def test_redact_status(self) -> None:
+ rm2_originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ if rm == self.rm2:
+ rm2_originals.append(join["event_id"])
+ for i in range(5):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ if rm == self.rm2:
+ rm2_originals.append(res["event_id"])
+
+ # redact messages in rooms 1 and 3
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel2.code, 200)
+ self.assertEqual(channel2.json_body.get("status"), "complete")
+ self.assertEqual(channel2.json_body.get("failed_redactions"), {})
+
+ # mock that will cause persisting the redaction events to fail
+ async def check_event_for_spam(event: str) -> str:
+ return "spam"
+
+ self.spam_checker.check_event_for_spam = check_event_for_spam # type: ignore
+
+ channel3 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm2]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel3.json_body.get("redact_id")
+
+ channel4 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel4.code, 200)
+ self.assertEqual(channel4.json_body.get("status"), "complete")
+ failed_redactions = channel4.json_body.get("failed_redactions")
+ assert failed_redactions is not None
+ matched = []
+ for original in rm2_originals:
+ if failed_redactions.get(original) is not None:
+ matched.append(original)
+ self.assertEqual(len(matched), len(rm2_originals))
+
+ def test_admin_redact_works_if_user_kicked_or_banned(self) -> None:
+ originals1 = []
+ originals2 = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ if rm in [self.rm1, self.rm3]:
+ originals1.append(join["event_id"])
+ else:
+ originals2.append(join["event_id"])
+ for i in range(5):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ if rm in [self.rm1, self.rm3]:
+ originals1.append(res["event_id"])
+ else:
+ originals2.append(res["event_id"])
+
+ # kick user from rooms 1 and 3
+ for r in [self.rm1, self.rm3]:
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{r}/kick",
+ content={"reason": "being a bummer", "user_id": self.bad_user},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # redact messages in room 1 and 3
+ channel1 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel1.code, 200)
+ id = channel1.json_body.get("redact_id")
+
+ # check that there were no failed redactions in room 1 and 3
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel2.code, 200)
+ self.assertEqual(channel2.json_body.get("status"), "complete")
+ failed_redactions = channel2.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ # double check
+ for rm in [self.rm1, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel3 = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel3.code, 200)
+
+ matches = []
+ for event in channel3.json_body["chunk"]:
+ for event_id in originals1:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matches.append((event_id, event))
+ # we redacted 6 messages
+ self.assertEqual(len(matches), 6)
+
+ # ban user from room 2
+ channel4 = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{self.rm2}/ban",
+ content={"reason": "being a bummer", "user_id": self.bad_user},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result)
+
+ # make a request to ban all user's messages
+ channel5 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel5.code, 200)
+ id2 = channel5.json_body.get("redact_id")
+
+ # check that there were no failed redactions in room 2
+ channel6 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id2}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel6.code, 200)
+ self.assertEqual(channel6.json_body.get("status"), "complete")
+ failed_redactions = channel6.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ # double check messages in room 2 were redacted
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel7 = self.make_request(
+ "GET",
+ f"rooms/{self.rm2}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel7.code, 200)
+
+ matches = []
+ for event in channel7.json_body["chunk"]:
+ for event_id in originals2:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ matches.append((event_id, event))
+ # we redacted 6 messages
+ self.assertEqual(len(matches), 6)
+
+ def test_redactions_for_remote_user_succeed_with_admin_priv_in_room(self) -> None:
+ """
+ Test that if the admin requester has privileges in a room, redaction requests
+ succeed for a remote user
+ """
+
+ # inject some messages from remote user and collect event ids
+ original_message_ids = []
+ for i in range(5):
+ event = self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.rm1,
+ type="m.room.message",
+ sender="@remote:remote_server",
+ content={"msgtype": "m.text", "body": f"nefarious_chatter{i}"},
+ )
+ )
+ original_message_ids.append(event.event_id)
+
+ # send a request to redact a remote user's messages in a room.
+ # the server admin created this room and has admin privilege in room
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/user/@remote:remote_server/redact",
+ content={"rooms": [self.rm1]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ # check that there were no failed redactions
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body.get("status"), "complete")
+ failed_redactions = channel.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{self.rm1}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in original_message_ids:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ original_message_ids.remove(event_id)
+ break
+ # we originally sent 5 messages so 5 should be redacted
+ self.assertEqual(len(original_message_ids), 0)
+
+ def test_redact_redacts_encrypted_messages(self) -> None:
+ """
+ Test that user's encrypted messages are redacted
+ """
+ encrypted_room = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.helper.send_state(
+ encrypted_room,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=self.admin_tok,
+ )
+ # join room send some messages
+ originals = []
+ join = self.helper.join(encrypted_room, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for _ in range(15):
+ res = self.helper.send_event(
+ encrypted_room, "m.room.encrypted", {}, tok=self.bad_user_tok
+ )
+ originals.append(res["event_id"])
+
+ # redact user's events
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matched = []
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{encrypted_room}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ matched.append(event_id)
+ self.assertEqual(len(matched), len(originals))
+
+
+class UserRedactionBackgroundTaskTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ # create rooms - room versions 11+ store the `redacts` key in content while
+ # earlier ones don't so we use a mix of room versions
+ self.rm1 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.rm3 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="11"
+ )
+
+ @override_config({"run_background_tasks_on": "worker1"})
+ def test_redact_messages_all_rooms(self) -> None:
+ """
+ Test that redact task successfully runs when `run_background_tasks_on` is specified
+ """
+ self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
+
+ # join rooms, send some messages
+ original_event_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ original_event_ids.add(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200
+ )
+ original_event_ids.add(res["event_id"])
+
+ # redact all events in all rooms
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ timeout_s = 10
+ start_time = time.time()
+ redact_result = ""
+ while redact_result != "complete":
+ if start_time + timeout_s < time.time():
+ self.fail("Timed out waiting for redactions.")
+
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ redact_result = channel2.json_body["status"]
+ if redact_result == "failed":
+ self.fail("Redaction task failed.")
+
+ redaction_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ if event["type"] == "m.room.redaction":
+ redaction_ids.add(event["redacts"])
+
+ self.assertIncludes(redaction_ids, original_event_ids, exact=True)
+
+
+class GetInvitesFromUserTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ self.random_users = []
+ for i in range(4):
+ self.random_users.append(self.register_user(f"user{i}", f"pass{i}"))
+
+ self.room1 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+ self.room2 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+ self.room3 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_issuer": {"per_second": 1000, "burst_count": 1000}}}
+ )
+ def test_get_user_invite_count_new_invites_test_case(self) -> None:
+ """
+ Test that new invites that arrive after a provided timestamp are counted
+ """
+ # grab a current timestamp
+ before_invites_sent_ts = self.hs.get_clock().time_msec()
+
+ # bad user sends some invites
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok)
+
+ # fetch using timestamp, all should be returned
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+ # send some more invites, they should show up in addition to original 8 using same timestamp
+ for user in self.random_users:
+ self.helper.invite(
+ self.room3, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 12)
+
+ def test_get_user_invite_count_invites_before_ts_test_case(self) -> None:
+ """
+ Test that invites sent before provided ts are not counted
+ """
+ # bad user sends some invites
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok)
+
+ # add a msec between last invite and ts
+ after_invites_sent_ts = self.hs.get_clock().time_msec() + 1
+
+ # fetch invites with timestamp, none should be returned
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={after_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 0)
+
+ def test_user_invite_count_kick_ban_not_counted(self) -> None:
+ """
+ Test that kicks and bans are not counted in invite count
+ """
+ to_kick_user_id = self.register_user("kick_me", "pass")
+ to_kick_tok = self.login("kick_me", "pass")
+
+ self.helper.join(self.room1, to_kick_user_id, tok=to_kick_tok)
+
+ # grab a current timestamp
+ before_invites_sent_ts = self.hs.get_clock().time_msec()
+
+ # bad user sends some invites (8)
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(
+ room_id, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ # fetch using timestamp, all invites sent should be counted
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+ # send a kick and some bans and make sure these aren't counted against invite total
+ for user in self.random_users:
+ self.helper.ban(
+ self.room1, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room1}/kick",
+ content={"user_id": to_kick_user_id},
+ access_token=self.bad_user_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+
+class GetCumulativeJoinedRoomCountForUserTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ def test_user_cumulative_joined_room_count(self) -> None:
+ """
+ Tests proper count returned from /cumulative_joined_room_count endpoint
+ """
+ # Create rooms and join, grab timestamp before room creation
+ before_room_creation_timestamp = self.hs.get_clock().time_msec()
+
+ joined_rooms = []
+ for _ in range(3):
+ room = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.helper.join(
+ room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok
+ )
+ joined_rooms.append(room)
+
+ # get a timestamp after room creation and join, add a msec between last join and ts
+ after_room_creation = self.hs.get_clock().time_msec() + 1
+
+ # Get rooms using this timestamp, there should be none since all rooms were created and joined
+ # before provided timestamp
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(after_room_creation)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["cumulative_joined_room_count"])
+
+ # fetch rooms with the older timestamp before they were created and joined, this should
+ # return the rooms
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
+
+ def test_user_joined_room_count_includes_left_and_banned_rooms(self) -> None:
+ """
+ Tests proper count returned from /joined_room_count endpoint when user has left
+ or been banned from joined rooms
+ """
+ # Create rooms and join, grab timestamp before room creation
+ before_room_creation_timestamp = self.hs.get_clock().time_msec()
+
+ joined_rooms = []
+ for _ in range(3):
+ room = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.helper.join(
+ room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok
+ )
+ joined_rooms.append(room)
+
+ # fetch rooms with the older timestamp before they were created and joined
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
+
+ # have the user banned from/leave the joined rooms
+ self.helper.ban(
+ joined_rooms[0],
+ src=self.admin,
+ targ=self.bad_user,
+ expect_code=200,
+ tok=self.admin_tok,
+ )
+ self.helper.change_membership(
+ joined_rooms[1],
+ src=self.bad_user,
+ targ=self.bad_user,
+ membership="leave",
+ expect_code=200,
+ tok=self.bad_user_tok,
+ )
+ self.helper.ban(
+ joined_rooms[2],
+ src=self.admin,
+ targ=self.bad_user,
+ expect_code=200,
+ tok=self.admin_tok,
+ )
+
+ # fetch the joined room count again, the number should remain the same as the collected joined rooms
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
|