diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+ """
+ A fake data store which stores a mapping of state key to event content.
+ (I.e. the state key is used as the event ID.)
+ """
+
+ def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+ """
+ Args:
+ events: A state map to event contents.
+ """
+ self._events = {}
+
+ for i, (event_id, content) in enumerate(events):
+ self._events[event_id] = FrozenEvent(
+ {
+ "event_id": "$event_id",
+ "type": event_id[0],
+ "sender": "@user:test",
+ "state_key": event_id[1],
+ "room_id": "#room:test",
+ "content": content,
+ "origin_server_ts": i,
+ },
+ RoomVersions.V1,
+ )
+
+ async def get_event(
+ self, event_id: StateKey, allow_none: bool = False
+ ) -> Optional[FrozenEvent]:
+ assert allow_none, "Mock not configured for allow_none = False"
+
+ return self._events.get(event_id)
+
+ async def get_events(self, event_ids: Iterable[StateKey]):
+ # This is cheating since it just returns all events.
+ return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+ USER_ID = "@test:test"
+ OTHER_USER_ID = "@user:test"
+
+ def _calculate_room_name(
+ self,
+ events: StateMap[dict],
+ user_id: str = "",
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+ ):
+ # This isn't 100% accurate, but works with MockDataStore.
+ room_state_ids = {k[0]: k[0] for k in events}
+
+ return self.get_success(
+ calculate_room_name(
+ MockDataStore(events),
+ room_state_ids,
+ user_id or self.USER_ID,
+ fallback_to_members,
+ fallback_to_single_member,
+ )
+ )
+
+ def test_name(self):
+ """A room name event should be used."""
+ events = [
+ ((EventTypes.Name, ""), {"name": "test-name"}),
+ ]
+ self.assertEqual("test-name", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.Name, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.Name, ""), {"name": 1})]
+ self.assertEqual(1, self._calculate_room_name(events))
+
+ def test_canonical_alias(self):
+ """An canonical alias should be used."""
+ events = [
+ ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+ ]
+ self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+ # Check if the event content has garbage.
+ events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_invite(self):
+ """An invite has special behaviour."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+ ]
+ self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+ # Ensure this logic is skipped if we don't fallback to members.
+ self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+ # Check if the event content has garbage.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ]
+ self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+ # No member event for sender.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+ ]
+ self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+ def test_no_members(self):
+ """Behaviour of an empty room."""
+ events = []
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ # Note that events with invalid (or missing) membership are ignored.
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+ ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+ ]
+ self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+ def test_no_other_members(self):
+ """Behaviour of a room with no other members in it."""
+ events = [
+ (
+ (EventTypes.Member, self.USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Me"},
+ ),
+ ]
+ self.assertEqual("Me", self._calculate_room_name(events))
+
+ # Check if the event content has no displayname.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("@test:test", self._calculate_room_name(events))
+
+ # 3pid invite, use the other user (who is set as the sender).
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual(
+ "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+ )
+
+ events = [
+ ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+ ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+ ]
+ self.assertEqual(
+ "Inviting email address",
+ self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+ )
+
+ def test_one_other_member(self):
+ """Behaviour of a room with a single other member."""
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ]
+ self.assertEqual("Other User", self._calculate_room_name(events))
+ self.assertIsNone(
+ self._calculate_room_name(events, fallback_to_single_member=False)
+ )
+
+ # Check if the event content has no displayname and is an invite.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.INVITE},
+ ),
+ ]
+ self.assertEqual("@user:test", self._calculate_room_name(events))
+
+ def test_other_members(self):
+ """Behaviour of a room with multiple other members."""
+ # Two other members.
+ events = [
+ ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+ (
+ (EventTypes.Member, self.OTHER_USER_ID),
+ {"membership": Membership.JOIN, "displayname": "Other User"},
+ ),
+ ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+ ]
+ self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+ # Three or more other members.
+ events.append(
+ ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+ )
+ self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"type": "m.room.history_visibility",
"sender": "@user:test",
"state_key": "",
- "room_id": "@room:test",
+ "room_id": "#room:test",
"content": content,
},
RoomVersions.V1,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 04599c2fcf..e48f8c1d7b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import make_awaitable
@@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.user1 = self.register_user(
- "user1", "pass1", admin=False, displayname="Name 1"
- )
- self.user2 = self.register_user(
- "user2", "pass2", admin=False, displayname="Name 2"
- )
-
def test_no_auth(self):
"""
Try to list users without authentication.
@@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
+ self._create_users(1)
other_user_token = self.login("user1", "pass1")
channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
List all users, including deactivated users.
"""
+ self._create_users(2)
+
channel = self.make_request(
"GET",
self.url + "?deactivated=true",
@@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"])
# Check that all fields are available
- for u in channel.json_body["users"]:
- self.assertIn("name", u)
- self.assertIn("is_guest", u)
- self.assertIn("admin", u)
- self.assertIn("user_type", u)
- self.assertIn("deactivated", u)
- self.assertIn("displayname", u)
- self.assertIn("avatar_url", u)
+ self._check_fields(channel.json_body["users"])
def test_search_term(self):
"""Test that searching for a users works correctly"""
@@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Check that users were returned
self.assertTrue("users" in channel.json_body)
+ self._check_fields(channel.json_body["users"])
users = channel.json_body["users"]
# Check that the expected number of users were returned
@@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0]
self.assertEqual(expected_user_id, u["name"])
+ self._create_users(2)
+
+ user1 = "@user1:test"
+ user2 = "@user2:test"
+
# Perform search tests
- _search_test(self.user1, "er1")
- _search_test(self.user1, "me 1")
+ _search_test(user1, "er1")
+ _search_test(user1, "me 1")
- _search_test(self.user2, "er2")
- _search_test(self.user2, "me 2")
+ _search_test(user2, "er2")
+ _search_test(user2, "me 2")
- _search_test(self.user1, "er1", "user_id")
- _search_test(self.user2, "er2", "user_id")
+ _search_test(user1, "er1", "user_id")
+ _search_test(user2, "er2", "user_id")
# Test case insensitive
- _search_test(self.user1, "ER1")
- _search_test(self.user1, "NAME 1")
+ _search_test(user1, "ER1")
+ _search_test(user1, "NAME 1")
- _search_test(self.user2, "ER2")
- _search_test(self.user2, "NAME 2")
+ _search_test(user2, "ER2")
+ _search_test(user2, "NAME 2")
- _search_test(self.user1, "ER1", "user_id")
- _search_test(self.user2, "ER2", "user_id")
+ _search_test(user1, "ER1", "user_id")
+ _search_test(user2, "ER2", "user_id")
_search_test(None, "foo")
_search_test(None, "bar")
@@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+
+ # negative limit
+ channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # negative from
+ channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid guests
+ channel = self.make_request(
+ "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid deactivated
+ channel = self.make_request(
+ "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ def test_limit(self):
+ """
+ Testing list of users with limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 5)
+ self.assertEqual(channel.json_body["next_token"], "5")
+ self._check_fields(channel.json_body["users"])
+
+ def test_from(self):
+ """
+ Testing list of users with a defined starting point (from)
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["users"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of users with a defined starting point and limit
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(channel.json_body["next_token"], "15")
+ self.assertEqual(len(channel.json_body["users"]), 10)
+ self._check_fields(channel.json_body["users"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ number_users = 20
+ # Create one less user (since there's already an admin user).
+ self._create_users(number_users - 1)
+
+ # `next_token` does not appear
+ # Number of results is the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does not appear
+ # Number of max results is larger than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), number_users)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` does appear
+ # Number of max results is smaller than the number of entries
+ channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 19)
+ self.assertEqual(channel.json_body["next_token"], "19")
+
+ # Check
+ # Set `from` to value of `next_token` for request remaining entries
+ # `next_token` does not appear
+ channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], number_users)
+ self.assertEqual(len(channel.json_body["users"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _check_fields(self, content: JsonDict):
+ """Checks that the expected user attributes are present in content
+ Args:
+ content: List that is checked for content
+ """
+ for u in content:
+ self.assertIn("name", u)
+ self.assertIn("is_guest", u)
+ self.assertIn("admin", u)
+ self.assertIn("user_type", u)
+ self.assertIn("deactivated", u)
+ self.assertIn("displayname", u)
+ self.assertIn("avatar_url", u)
+
+ def _create_users(self, number_users: int):
+ """
+ Create a number of users
+ Args:
+ number_users: Number of users to be created
+ """
+ for i in range(1, number_users + 1):
+ self.register_user(
+ "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+ )
+
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index ae2b32b131..a6c6985173 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config()
config["media_store_path"] = self.media_store_path
- config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
@@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
+ """Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self):
+ """Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
+ def test_invalid_type(self):
+ """An invalid thumbnail type is never available."""
+ self._test_thumbnail("invalid", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+ )
+ def test_no_thumbnail_crop(self):
+ """
+ Override the config to generate only scaled thumbnails, but request a cropped one.
+ """
+ self._test_thumbnail("crop", None, False)
+
+ @unittest.override_config(
+ {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+ )
+ def test_no_thumbnail_scale(self):
+ """
+ Override the config to generate only cropped thumbnails, but request a scaled one.
+ """
+ self._test_thumbnail("scale", None, False)
+
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
channel = make_request(
|