diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index fb36aa9940..6cf56b1e35 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+ self.updater = BackgroundUpdater(hs, self.store.db_pool)
@parameterized.expand(
[
@@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"""Test the status API works with a background update."""
# Create a new background update
-
self._register_bg_update()
self.store.db_pool.updates.start_doing_background_updates()
+
self.reactor.pump([1.0, 1.0, 1.0])
channel = self.make_request(
@@ -155,10 +156,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ self.updater.default_background_batch_size
),
}
},
@@ -210,10 +211,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ self.updater.default_background_batch_size
),
}
},
@@ -239,10 +240,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.1,
"total_duration_ms": 1000.0,
"total_item_count": (
- BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ self.updater.default_background_batch_size
),
}
},
@@ -278,11 +279,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.001,
+ "average_items_per_ms": 0.05263157894736842,
"total_duration_ms": 2000.0,
- "total_item_count": (
- 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
- ),
+ "total_item_count": (110),
}
},
"enabled": True,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index a60ea0a563..bef911d5df 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", True)
+ @override_config({"max_avatar_size": 1234})
+ def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None:
+ """Check we can erase a user whose avatar is the empty string.
+
+ Reproduces #12257.
+ """
+ # Patch `self.other_user` to have an empty string as their avatar.
+ self.get_success(self.store.set_profile_avatar_url("user", ""))
+
+ # Check we can still erase them.
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"erase": True},
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self._is_erased("@user:test", True)
+
def test_deactivate_user_erase_false(self) -> None:
"""
Test deactivating a user and set `erase` to `false`
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[users[2]],
)
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_no_account_validity(self) -> None:
+ """Tests that if we decide to include account validity in the response but no
+ account validity 'is_user_expired' callback is provided, we default to marking all
+ users as not expired.
+ """
+ user = self.register_user("someuser", "password")
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": False,
+ },
+ },
+ expected_failures=[],
+ )
+
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_account_validity_expired(self) -> None:
+ """Test that if we decide to include account validity in the response and the user
+ is expired, we return the correct info.
+ """
+ user = self.register_user("someuser", "password")
+
+ async def is_expired(user_id: str) -> bool:
+ # We can't blindly say everyone is expired, otherwise the request to get the
+ # account status will fail.
+ return UserID.from_string(user_id).localpart == "someuser"
+
+ self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+ is_expired
+ )
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": True,
+ },
+ },
+ expected_failures=[],
+ )
+
def _test_status(
self,
users: Optional[List[str]],
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 3818b7b14b..7b7d283bb6 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -14,7 +14,7 @@
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.rest.client import login, room, shared_rooms
+from synapse.rest.client import login, mutual_rooms, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -22,16 +22,16 @@ from tests import unittest
from tests.server import FakeChannel
-class UserSharedRoomsTest(unittest.HomeserverTestCase):
+class UserMutualRoomsTest(unittest.HomeserverTestCase):
"""
- Tests the UserSharedRoomsServlet.
+ Tests the UserMutualRoomsServlet.
"""
servlets = [
login.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
- shared_rooms.register_servlets,
+ mutual_rooms.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_user_directory_handler()
- def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
+ def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
"GET",
- "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
% other_user,
access_token=token,
)
@@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
A room should show up in the shared list of rooms between two users
if it is public.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is private.
"""
- self._check_shared_rooms_with(
+ self._check_mutual_rooms_with(
room_one_is_public=False, room_two_is_public=False
)
@@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
The shared room list between two users should contain both public and private
rooms.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
- def _check_shared_rooms_with(
+ def _check_mutual_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool
) -> None:
"""Checks that shared public or private rooms between two users appear in
@@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
# Check shared rooms from user1's perspective.
# We should see the one room in common
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room_id_one)
@@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room_id_two, user=u2, tok=u2_token)
# Check shared rooms again. We should now see both rooms.
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 2)
for room_id_id in channel.json_body["joined"]:
@@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room)
@@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.leave(room, user=u1, tok=u1_token)
# Check user1's view of shared rooms with user2
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
# Check user2's view of shared rooms with user1
- channel = self._get_shared_rooms(u2_token, u1)
+ channel = self._get_mutual_rooms(u2_token, u1)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 709f851a38..fe97a0b3dd 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,17 +15,16 @@
import itertools
import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
from synapse.server import HomeServer
-from synapse.storage.relations import RelationPaginationToken
-from synapse.types import JsonDict, StreamToken
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -80,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content: Optional[dict] = None,
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
+ expected_response_code: int = 200,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
@@ -116,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content,
access_token=access_token,
)
+ self.assertEqual(expected_response_code, channel.code, channel.json_body)
return channel
+ def _get_related_events(self) -> List[str]:
+ """
+ Requests /relations on the parent ID and returns a list of event IDs.
+ """
+ # Request the relations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+ def _get_bundled_aggregations(self) -> JsonDict:
+ """
+ Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+ """
+ # Fetch the bundled aggregations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ return channel.json_body["unsigned"].get("m.relations", {})
+
+ def _get_aggregations(self) -> List[JsonDict]:
+ """Request /aggregations on the parent ID and includes the returned chunk."""
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return channel.json_body["chunk"]
+
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
"""Tests that sending a relation works."""
-
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
-
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -152,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_deny_invalid_event(self) -> None:
"""Test that we deny relations on non-existant events"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.ANNOTATION,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
# Unless that event is referenced from another event!
self.get_success(
@@ -172,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase):
desc="test_deny_invalid_event",
)
)
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
def test_deny_invalid_room(self) -> None:
"""Test that we deny relations on non-existant events"""
@@ -188,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase):
parent_id = res["event_id"]
# Attempt to send an annotation to that event.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ parent_id=parent_id,
+ key="A",
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
def test_deny_double_react(self) -> None:
"""Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(400, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+ )
def test_deny_forked_thread(self) -> None:
"""It is invalid to start a thread off a thread."""
@@ -209,386 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo"},
parent_id=self.parent_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
parent_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=parent_id,
- )
- self.assertEqual(400, channel.code, channel.json_body)
-
- def test_basic_paginate_relations(self) -> None:
- """Tests that calling pagination API correctly the latest relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- first_annotation_id = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
- second_annotation_id = channel.json_body["event_id"]
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the latest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": second_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
- # Make sure next_batch has something in it that looks like it could be a
- # valid token.
- self.assertIsInstance(
- channel.json_body.get("next_batch"), str, channel.json_body
- )
-
- # Request the relations again, but with a different direction.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the earliest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": first_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- def _stream_token_to_relation_token(self, token: str) -> str:
- """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
- room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
- return self.get_success(
- RelationPaginationToken(
- topological=room_key.topological, stream=room_key.stream
- ).to_string(self.store)
- )
-
- def test_repeated_paginate_relations(self) -> None:
- """Test that if we paginate using a limit and tokens then we get the
- expected events.
- """
-
- expected_event_ids = []
- for idx in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- prev_token = ""
- found_event_ids: List[str] = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- def test_pagination_from_sync_and_messages(self) -> None:
- """Pagination tokens from /sync and /messages can be used to paginate /relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
- # Send an event after the relation events.
- self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
- # Request /sync, limiting it such that only the latest event is returned
- # (and not the relation).
- filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
- channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- sync_prev_batch = room_timeline["prev_batch"]
- self.assertIsNotNone(sync_prev_batch)
- # Ensure the relation event is not in the batch returned from /sync.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
- )
-
- # Request /messages, limiting it such that only the latest event is
- # returned (and not the relation).
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b&limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- messages_end = channel.json_body["end"]
- self.assertIsNotNone(messages_end)
- # Ensure the relation event is not in the chunk returned from /messages.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ expected_response_code=400,
)
- # Request /relations with the pagination tokens received from both the
- # /sync and /messages responses above, in turn.
- #
- # This is a tiny bit silly since the client wouldn't know the parent ID
- # from the requests above; consider the parent ID to be known from a
- # previous /sync.
- for from_token in (sync_prev_batch, messages_end):
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # The relation should be in the returned chunk.
- self.assertIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
- )
-
- def test_aggregation_pagination_groups(self) -> None:
- """Test that we can paginate annotation groups correctly."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
- for key in itertools.chain.from_iterable(
- itertools.repeat(key, num) for key, num in sent_groups.items()
- ):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key=key,
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- idx += 1
- idx %= len(access_tokens)
-
- prev_token: Optional[str] = None
- found_groups: Dict[str, int] = {}
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- for groups in channel.json_body["chunk"]:
- # We only expect reactions
- self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
- # We should only see each key once
- self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
- found_groups[groups["key"]] = groups["count"]
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- self.assertEqual(sent_groups, found_groups)
-
- def test_aggregation_pagination_within_group(self) -> None:
- """Test that we can paginate within an annotation group."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="👍",
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- idx += 1
-
- # Also send a different type of reaction so that we test we don't see it
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- prev_token = ""
- found_event_ids: List[str] = []
- encoded_key = urllib.parse.quote_plus("👍".encode())
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
def test_aggregation(self) -> None:
"""Test that annotations get correctly aggregated."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
channel = self.make_request(
"GET",
@@ -618,220 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
- @unittest.override_config(
- {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
- )
- def test_bundled_aggregations(self) -> None:
- """
- Test that annotations, references, and threads get correctly bundled.
-
- Note that this doesn't test against /relations since only thread relations
- get bundled via that API. See test_aggregation_get_event_for_thread.
-
- See test_edit for a similar test for edits.
- """
- # Setup by sending a variety of relations.
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_1 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_2 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_2 = channel.json_body["event_id"]
-
- def assert_bundle(event_json: JsonDict) -> None:
- """Assert the expected values of the bundled aggregations."""
- relations_dict = event_json["unsigned"].get("m.relations")
-
- # Ensure the fields are as expected.
- self.assertCountEqual(
- relations_dict.keys(),
- (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.THREAD,
- ),
- )
-
- # Check the values of each field.
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- relations_dict[RelationTypes.ANNOTATION],
- )
-
- self.assertEqual(
- {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- relations_dict[RelationTypes.REFERENCE],
- )
-
- self.assertEqual(
- 2,
- relations_dict[RelationTypes.THREAD].get("count"),
- )
- self.assertTrue(
- relations_dict[RelationTypes.THREAD].get("current_user_participated")
- )
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
- },
- "event_id": thread_2,
- "room_id": self.room,
- "sender": self.user_id,
- "type": "m.room.test",
- "user_id": self.user_id,
- },
- relations_dict[RelationTypes.THREAD].get("latest_event"),
- )
-
- # Request the event directly.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body)
-
- # Request the room messages.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
- # Request the room context.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"])
-
- # Request sync.
- channel = self.make_request("GET", "/sync", access_token=self.user_token)
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
- # Request search.
- channel = self.make_request(
- "POST",
- "/search",
- # Search term matches the parent message.
- content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- chunk = [
- result["result"]
- for result in channel.json_body["search_categories"]["room_events"][
- "results"
- ]
- ]
- assert_bundle(self._find_event_in_chunk(chunk))
-
- def test_aggregation_get_event_for_annotation(self) -> None:
- """Test that annotations do not get bundled aggregations included
- when directly requested.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{annotation_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
- def test_aggregation_get_event_for_thread(self) -> None:
- """Test that threads get bundled aggregations included when directly requested."""
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{thread_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- # It should also be included when the entire thread is requested.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- thread_message = channel.json_body["chunk"][0]
- self.assertEqual(
- thread_message["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -953,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
def assert_bundle(event_json: JsonDict) -> None:
@@ -1030,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase):
shouldn't be allowed, are correctly handled.
"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -1039,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
@@ -1047,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message.WRONG_TYPE",
content={
@@ -1060,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
@@ -1091,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
reply = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1101,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=reply,
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1138,7 +598,6 @@ class RelationsTestCase(BaseRelationsTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_edit_thread(self) -> None:
"""Test that editing a thread works."""
@@ -1148,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A threaded reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=threaded_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Fetch the thread root, to get the bundled aggregation for the thread.
channel = self.make_request(
@@ -1190,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": new_body,
},
)
- self.assertEqual(200, channel.code, channel.json_body)
edit_event_id = channel.json_body["event_id"]
# Edit the edit event.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -1204,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
},
parent_id=edit_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Request the original event.
channel = self.make_request(
@@ -1231,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1272,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- return event
-
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_good = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_bad = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
thread_event_id = channel.json_body["event_id"]
# Clean-up the table as if the inserts did not happen during event creation.
@@ -1345,8 +786,638 @@ class RelationsTestCase(BaseRelationsTestCase):
)
+class RelationPaginationTestCase(BaseRelationsTestCase):
+ def test_basic_paginate_relations(self) -> None:
+ """Tests that calling pagination API correctly the latest relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ first_annotation_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ second_annotation_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the latest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": second_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEqual(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), str, channel.json_body
+ )
+
+ # Request the relations again, but with a different direction.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations"
+ f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the earliest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": first_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ def test_repeated_paginate_relations(self) -> None:
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
+ """
+
+ expected_event_ids = []
+ for idx in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+ def test_pagination_from_sync_and_messages(self) -> None:
+ """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+ annotation_id = channel.json_body["event_id"]
+ # Send an event after the relation events.
+ self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+ # Request /sync, limiting it such that only the latest event is returned
+ # (and not the relation).
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ sync_prev_batch = room_timeline["prev_batch"]
+ self.assertIsNotNone(sync_prev_batch)
+ # Ensure the relation event is not in the batch returned from /sync.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+ )
+
+ # Request /messages, limiting it such that only the latest event is
+ # returned (and not the relation).
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b&limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ messages_end = channel.json_body["end"]
+ self.assertIsNotNone(messages_end)
+ # Ensure the relation event is not in the chunk returned from /messages.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ # Request /relations with the pagination tokens received from both the
+ # /sync and /messages responses above, in turn.
+ #
+ # This is a tiny bit silly since the client wouldn't know the parent ID
+ # from the requests above; consider the parent ID to be known from a
+ # previous /sync.
+ for from_token in (sync_prev_batch, messages_end):
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # The relation should be in the returned chunk.
+ self.assertIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ def test_aggregation_pagination_groups(self) -> None:
+ """Test that we can paginate annotation groups correctly."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token: Optional[str] = None
+ found_groups: Dict[str, int] = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self) -> None:
+ """Test that we can paginate within an annotation group."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="👍",
+ access_token=access_tokens[idx],
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ idx += 1
+
+ # Also send a different type of reaction so that we test we don't see it
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ encoded_key = urllib.parse.quote_plus("👍".encode())
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+ """
+ See RelationsTestCase.test_edit for a similar test for edits.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+ """
+
+ def _test_bundled_aggregations(
+ self,
+ relation_type: str,
+ assertion_callable: Callable[[JsonDict], None],
+ expected_db_txn_for_event: int,
+ ) -> None:
+ """
+ Makes requests to various endpoints which should include bundled aggregations
+ and then calls an assertion function on the bundled aggregations.
+
+ Args:
+ relation_type: The field to search for in the `m.relations` field in unsigned.
+ assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+ for relation-specific assertions.
+ expected_db_txn_for_event: The number of database transactions which
+ are expected for a call to /event/.
+ """
+
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+
+ # Ensure the fields are as expected.
+ self.assertCountEqual(relations_dict.keys(), (relation_type,))
+ assertion_callable(relations_dict[relation_type])
+
+ # Request the event directly.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
+
+ # Request sync.
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_annotation(self) -> None:
+ """
+ Test that annotations get correctly bundled.
+ """
+ # Setup by sending a variety of relations.
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_reference(self) -> None:
+ """
+ Test that references get correctly bundled.
+ """
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_thread(self) -> None:
+ """
+ Test that threads get correctly bundled.
+ """
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertTrue(bundled_aggregations.get("current_user_participated"))
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ },
+ bundled_aggregations.get("latest_event"),
+ )
+
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+ def test_aggregation_get_event_for_annotation(self) -> None:
+ """Test that annotations do not get bundled aggregations included
+ when directly requested.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ annotation_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{annotation_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+ def test_aggregation_get_event_for_thread(self) -> None:
+ """Test that threads get bundled aggregations included when directly requested."""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{thread_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ # It should also be included when the entire thread is requested.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+ thread_message = channel.json_body["chunk"][0]
+ self.assertEqual(
+ thread_message["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ def test_bundled_aggregations_with_filter(self) -> None:
+ """
+ If "unsigned" is an omitted field (due to filtering), adding the bundled
+ aggregations should not break.
+
+ Note that the spec allows for a server to return additional fields beyond
+ what is specified.
+ """
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+
+ # Note that the sync filter does not include "unsigned" as a field.
+ filter = urllib.parse.quote_plus(
+ b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # Ensure the timeline is limited, find the parent event.
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ parent_event = self._find_event_in_chunk(room_timeline["events"])
+
+ # Ensure there's bundled aggregations on it.
+ self.assertIn("unsigned", parent_event)
+ self.assertIn("m.relations", parent_event["unsigned"])
+
+
+class RelationIgnoredUserTestCase(BaseRelationsTestCase):
+ """Relations sent from an ignored user should be ignored."""
+
+ def _test_ignored_user(
+ self, allowed_event_ids: List[str], ignored_event_ids: List[str]
+ ) -> None:
+ """
+ Fetch the relations and ensure they're all there, then ignore user2, and
+ repeat.
+ """
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids)
+
+ # Ignore user2 and re-do the requests.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user2_id: {}}},
+ )
+ )
+
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids)
+
+ def test_annotation(self) -> None:
+ """Annotations should ignore"""
+ # Send 2 from us, 2 from the to be ignored user.
+ allowed_event_ids = []
+ ignored_event_ids = []
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="a",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="c",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_reference(self) -> None:
+ """Annotations should ignore"""
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_thread(self) -> None:
+ """Annotations should ignore"""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+
class RelationRedactionTestCase(BaseRelationsTestCase):
- """Test the behaviour of relations when the parent or child event is redacted."""
+ """
+ Test the behaviour of relations when the parent or child event is redacted.
+
+ The behaviour of each relation type is subtly different which causes the tests
+ to be a bit repetitive, they follow a naming scheme of:
+
+ test_redact_(relation|parent)_{relation_type}
+
+ The first bit of "relation" means that the event with the relation defined
+ on it (the child event) is to be redacted. A "parent" means that the target
+ of the relation (the parent event) is to be redacted.
+
+ The relation_type describes which type of relation is under test (i.e. it is
+ related to the value of rel_type in the event content).
+ """
def _redact(self, event_id: str) -> None:
channel = self.make_request(
@@ -1358,40 +1429,116 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
def test_redact_relation_annotation(self) -> None:
- """Test that annotations of an event are properly handled after the
+ """
+ Test that annotations of an event are properly handled after the
annotation is redacted.
+
+ The redacted relation should not be included in bundled aggregations or
+ the response to relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Both relations should exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
+ )
+
+ # Both relations appear in the aggregation.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
# Redact one of the reactions.
self._redact(to_redact_event_id)
- # Ensure that the aggregations are correct.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [unredacted_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
- self.assertEqual(200, channel.code, channel.json_body)
+ # The unredacted aggregation should still exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
+
+ def test_redact_relation_thread(self) -> None:
+ """
+ Test that thread replies are properly handled after the thread reply redacted.
+
+ The redacted event should not be included in bundled aggregations or
+ the response to relations.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Note that the *last* event in the thread is redacted, as that gets
+ # included in the bundled aggregation.
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 2", "msgtype": "m.text"},
+ )
+ to_redact_event_id = channel.json_body["event_id"]
+
+ # Both relations exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 2,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event returned is the event that will be redacted.
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ to_redact_event_id,
+ )
+
+ # Redact one of the reactions.
+ self._redact(to_redact_event_id)
+
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event is now the unredacted event.
self.assertEqual(
- channel.json_body,
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ unredacted_event_id,
)
- def test_redact_relation_edit(self) -> None:
+ def test_redact_parent_edit(self) -> None:
"""Test that edits of an event are redacted when the original event
is redacted.
"""
# Add a relation
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
parent_id=self.parent_id,
@@ -1401,54 +1548,83 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
# Check the relation is returned
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}/m.replace/m.room.message",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.REPLACE, relations)
# Redact the original event
self._redact(self.parent_id)
- # Try to check for remaining m.replace relations
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}/m.replace/m.room.message",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
+ # The relations are not returned.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 0)
+ self.assertEqual(relations, {})
- # Check that no relations are returned
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
- def test_redact_parent(self) -> None:
- """Test that annotations of an event are redacted when the original event
+ def test_redact_parent_annotation(self) -> None:
+ """Test that annotations of an event are viewable when the original event
is redacted.
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
+ related_event_id = channel.json_body["event_id"]
+
+ # The relations should exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.ANNOTATION, relations)
+
+ # The aggregation should exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
# Redact the original event.
self._redact(self.parent_id)
- # Check that aggregations returns zero
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
- access_token=self.user_token,
+ # The relations are returned.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(event_ids, [related_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
+ # There's nothing to aggregate.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
+
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_redact_parent_thread(self) -> None:
+ """
+ Test that thread replies are still available when the root event is redacted.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ related_event_id = channel.json_body["event_id"]
+
+ # Redact one of the reactions.
+ self._redact(self.parent_id)
+
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEquals(len(event_ids), 1)
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ related_event_id,
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f3bf8d0934..7b8fe6d025 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -24,6 +24,7 @@ from synapse.util import Clock
from synapse.visibility import filter_events_for_client
from tests import unittest
+from tests.unittest import override_config
one_hour_ms = 3600000
one_day_ms = one_hour_ms * 24
@@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["retention"] = {
+
+ # merge this default retention config with anything that was specified in
+ # @override_config
+ retention_config = {
"enabled": True,
"default_policy": {
"min_lifetime": one_day_ms,
@@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase):
"allowed_lifetime_min": one_day_ms,
"allowed_lifetime_max": one_day_ms * 3,
}
+ retention_config.update(config.get("retention", {}))
+ config["retention"] = retention_config
self.hs = self.setup_test_homeserver(config=config)
@@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 2)
+ @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
def test_visibility(self) -> None:
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
- outdated events
+ outdated events, even if the purge job hasn't got to them yet.
+
+ We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
storage = self.hs.get_storage()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
- events = []
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
-
- # Get the event from the store so that we end up with a FrozenEvent that we can
- # give to filter_events_for_client. We need to do this now because the event won't
- # be in the database anymore after it has expired.
- events.append(self.get_success(store.get_event(resp.get("event_id"))))
+ first_event_id = resp.get("event_id")
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
-
valid_event_id = resp.get("event_id")
- events.append(self.get_success(store.get_event(valid_event_id)))
-
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
- # Run filter_events_for_client with our list of FrozenEvents.
+ # Fetch the events, and run filter_events_for_client on them
+ events = self.get_success(
+ store.get_events_as_list([first_event_id, valid_event_id])
+ )
+ self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 37866ee330..3a9617d6da 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
- filter = {"io.element.relation_senders": [self.second_user_id]}
+ filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which third user reacted to.
- filter = {"io.element.relation_senders": [self.third_user_id]}
+ filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which either user reacted to.
- filter = {
- "io.element.relation_senders": [self.second_user_id, self.third_user_id]
- }
+ filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
@@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_type(self) -> None:
# Messages which have annotations.
- filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which have references.
- filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which have either annotations or references.
filter = {
- "io.element.relation_types": [
+ "related_by_rel_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
@@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
- "io.element.relation_senders": [self.second_user_id],
- "io.element.relation_types": [RelationTypes.ANNOTATION],
+ "related_by_senders": [self.second_user_id],
+ "related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 58f1ea11b7..e7de67e3a3 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(args[0], user_id)
self.assertFalse(args[1])
self.assertTrue(args[2])
+
+ def test_check_can_deactivate_user(self) -> None:
+ """Tests that the on_user_deactivation_status_changed module callback is called
+ correctly when processing a user's deactivation.
+ """
+ # Register a mocked callback.
+ deactivation_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_deactivate_user_callbacks.append(
+ deactivation_mock,
+ )
+
+ # Register a user that we'll deactivate.
+ user_id = self.register_user("altan", "password")
+ tok = self.login("altan", "password")
+
+ # Deactivate that user.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/account/deactivate",
+ {
+ "auth": {
+ "type": LoginType.PASSWORD,
+ "password": "password",
+ "identifier": {
+ "type": "m.id.user",
+ "user": user_id,
+ },
+ },
+ "erase": True,
+ },
+ access_token=tok,
+ )
+
+ # Check that the deactivation was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ deactivation_mock.assert_called_once()
+ args = deactivation_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], user_id)
+
+ # Check that the request was not made by an admin
+ self.assertEqual(args[1], False)
+
+ def test_check_can_deactivate_user_admin(self) -> None:
+ """Tests that the on_user_deactivation_status_changed module callback is called
+ correctly when processing a user's deactivation triggered by a server admin.
+ """
+ # Register a mocked callback.
+ deactivation_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_deactivate_user_callbacks.append(
+ deactivation_mock,
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Register a user that we'll deactivate.
+ user_id = self.register_user("altan", "password")
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {"deactivated": True},
+ access_token=admin_tok,
+ )
+
+ # Check that the deactivation was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ deactivation_mock.assert_called_once()
+ args = deactivation_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], user_id)
+
+ # Check that the mock was made by an admin
+ self.assertEqual(args[1], True)
+
+ def test_check_can_shutdown_room(self) -> None:
+ """Tests that the check_can_shutdown_room module callback is called
+ correctly when processing an admin's shutdown room request.
+ """
+ # Register a mocked callback.
+ shutdown_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_shutdown_room_callbacks.append(
+ shutdown_mock,
+ )
+
+ # Register an admin user.
+ admin_user_id = self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Shutdown the room.
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v2/rooms/%s" % self.room_id,
+ {},
+ access_token=admin_tok,
+ )
+
+ # Check that the shutdown was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ shutdown_mock.assert_called_once()
+ args = shutdown_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], admin_user_id)
+
+ # Check that the mock was called with the right room ID
+ self.assertEqual(args[1], self.room_id)
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3b5747cb12..8d8251b2ac 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,3 +1,18 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
from unittest.mock import Mock, call
from twisted.internet import defer, reactor
@@ -11,14 +26,14 @@ from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
- self.mock_http_response = (200, "GOOD JOB!")
+ self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_key = "foo"
@defer.inlineCallbacks
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 4672a68596..978c252f84 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -13,19 +13,24 @@
# limitations under the License.
import urllib.parse
from io import BytesIO, StringIO
+from typing import Any, Dict, Optional, Union
from unittest.mock import Mock
import signedjson.key
from canonicaljson import encode_canonical_json
-from nacl.signing import SigningKey
from signedjson.sign import sign_json
+from signedjson.types import SigningKey
-from twisted.web.resource import NoResource
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest
from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
+from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
@@ -35,11 +40,11 @@ from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
return self.setup_test_homeserver(federation_http_client=self.http_client)
- def create_test_resource(self):
+ def create_test_resource(self) -> Resource:
return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
)
@@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
- async def get_json(destination, path, ignore_backoff=False, **kwargs):
+ async def get_json(
+ destination: str,
+ path: str,
+ ignore_backoff: bool = False,
+ **kwargs: Any,
+ ) -> Union[JsonDict, list]:
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body.
"""
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel, self.site)
+ # channel is a `FakeChannel` but `HTTPChannel` is expected
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(b"")
req.requestReceived(
b"GET",
@@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
resp = channel.json_body
return resp
- def test_get_key(self):
+ def test_get_key(self) -> None:
"""Fetch a remote key"""
SERVER_NAME = "remote.server"
testkey = signedjson.key.generate_signing_key("ver1")
@@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
self.assertIn(SERVER_NAME, keys[0]["signatures"])
self.assertIn(self.hs.hostname, keys[0]["signatures"])
- def test_get_own_key(self):
+ def test_get_own_key(self) -> None:
"""Fetch our own key"""
testkey = signedjson.key.generate_signing_key("ver1")
self.expect_outgoing_key_request(self.hs.hostname, testkey)
@@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible.
"""
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# replace the signing key with our own
@@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# make a second homeserver, configured to use the first one as a key notary
self.http_client2 = Mock()
config = default_config(name="keyclient")
@@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
- async def post_json(destination, path, data):
+ async def post_json(
+ destination: str, path: str, data: Optional[JsonDict] = None
+ ) -> Union[JsonDict, list]:
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path,
@@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
)
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel, self.site)
+ # channel is a `FakeChannel` but `HTTPChannel` is expected
+ req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
@@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
self.http_client2.post_json.side_effect = post_json
- def test_get_key(self):
+ def test_get_key(self) -> None:
"""Fetch a key belonging to a random server"""
# make up a key to be fetched.
testkey = signedjson.key.generate_signing_key("abc")
@@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
- def test_get_notary_key(self):
+ def test_get_notary_key(self) -> None:
"""Fetch a key belonging to the notary server"""
# make up a key to be fetched. We randomise the keyid to try to get it to
# appear before the key server signing key sometimes (otherwise we bail out
@@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
- def test_get_notary_keyserver_key(self):
+ def test_get_notary_keyserver_key(self) -> None:
"""Fetch the notary's keyserver key"""
# we expect hs1 to make a regular key request to itself
self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index f761e23f1b..c73179151a 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
}
- def tests(self):
+ def tests(self) -> None:
for hdr, expected in self.TEST_CASES.items():
res = get_filename_from_headers({b"Content-Disposition": [hdr]})
self.assertEqual(
res,
expected,
- "expected output for %s to be %s but was %s" % (hdr, expected, res),
+ f"expected output for {hdr!r} to be {expected} but was {res}",
)
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
index 913bc530aa..43e6f0f70a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/rest/media/v1/test_filepath.py
@@ -21,12 +21,12 @@ from tests import unittest
class MediaFilePathsTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
super().setUp()
self.filepaths = MediaFilePaths("/media_store")
- def test_local_media_filepath(self):
+ def test_local_media_filepath(self) -> None:
"""Test local media paths"""
self.assertEqual(
self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_local_media_thumbnail(self):
+ def test_local_media_thumbnail(self) -> None:
"""Test local media thumbnail paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_rel(
@@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_local_media_thumbnail_dir(self):
+ def test_local_media_thumbnail_dir(self) -> None:
"""Test local media thumbnail directory paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_remote_media_filepath(self):
+ def test_remote_media_filepath(self) -> None:
"""Test remote media paths"""
self.assertEqual(
self.filepaths.remote_media_filepath_rel(
@@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_remote_media_thumbnail(self):
+ def test_remote_media_thumbnail(self) -> None:
"""Test remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel(
@@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_remote_media_thumbnail_legacy(self):
+ def test_remote_media_thumbnail_legacy(self) -> None:
"""Test old-style remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel_legacy(
@@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
)
- def test_remote_media_thumbnail_dir(self):
+ def test_remote_media_thumbnail_dir(self) -> None:
"""Test remote media thumbnail directory paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_dir(
@@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_filepath(self):
+ def test_url_cache_filepath(self) -> None:
"""Test URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
@@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
)
- def test_url_cache_filepath_legacy(self):
+ def test_url_cache_filepath_legacy(self) -> None:
"""Test old-style URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_filepath_dirs_to_delete(self):
+ def test_url_cache_filepath_dirs_to_delete(self) -> None:
"""Test URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
["/media_store/url_cache/2020-01-02"],
)
- def test_url_cache_filepath_dirs_to_delete_legacy(self):
+ def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_url_cache_thumbnail(self):
+ def test_url_cache_thumbnail(self) -> None:
"""Test URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
)
- def test_url_cache_thumbnail_legacy(self):
+ def test_url_cache_thumbnail_legacy(self) -> None:
"""Test old-style URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
- def test_url_cache_thumbnail_directory(self):
+ def test_url_cache_thumbnail_directory(self) -> None:
"""Test URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
)
- def test_url_cache_thumbnail_directory_legacy(self):
+ def test_url_cache_thumbnail_directory_legacy(self) -> None:
"""Test old-style URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
- def test_url_cache_thumbnail_dirs_to_delete(self):
+ def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
"""Test URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+ def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_server_name_validation(self):
+ def test_server_name_validation(self) -> None:
"""Test validation of server names"""
self._test_path_validation(
[
@@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_file_id_validation(self):
+ def test_file_id_validation(self) -> None:
"""Test validation of local, remote and legacy URL cache file / media IDs"""
# File / media IDs get split into three parts to form paths, consisting of the
# first two characters, next two characters and rest of the ID.
@@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
invalid_values=invalid_file_ids,
)
- def test_url_cache_media_id_validation(self):
+ def test_url_cache_media_id_validation(self) -> None:
"""Test validation of URL cache media IDs"""
self._test_path_validation(
[
@@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_content_type_validation(self):
+ def test_content_type_validation(self) -> None:
"""Test validation of thumbnail content types"""
self._test_path_validation(
[
@@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
- def test_thumbnail_method_validation(self):
+ def test_thumbnail_method_validation(self) -> None:
"""Test validation of thumbnail methods"""
self._test_path_validation(
[
@@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
parameter: str,
valid_values: Iterable[str],
invalid_values: Iterable[str],
- ):
+ ) -> None:
"""Test that the specified methods validate the named parameter as expected
Args:
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index a4b57e3d1f..62e308814d 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
_get_html_media_encodings,
decode_body,
parse_html_to_open_graph,
- rebase_url,
summarize_paragraphs,
)
@@ -32,7 +31,7 @@ class SummarizeTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
- def test_long_summarize(self):
+ def test_long_summarize(self) -> None:
example_paras = [
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
@@ -90,7 +89,7 @@ class SummarizeTestCase(unittest.TestCase):
" Tromsøya had a population of 36,088. Substantial parts of the urban…",
)
- def test_short_summarize(self):
+ def test_short_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -117,7 +116,7 @@ class SummarizeTestCase(unittest.TestCase):
" most of the year.",
)
- def test_small_then_large_summarize(self):
+ def test_small_then_large_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -150,7 +149,7 @@ class CalcOgTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
- def test_simple(self):
+ def test_simple(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -161,11 +160,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_comment(self):
+ def test_comment(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -177,11 +176,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_comment2(self):
+ def test_comment2(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -206,7 +205,7 @@ class CalcOgTestCase(unittest.TestCase):
},
)
- def test_script(self):
+ def test_script(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -218,11 +217,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_missing_title(self):
+ def test_missing_title(self) -> None:
html = b"""
<html>
<body>
@@ -232,11 +231,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
- def test_h1_as_title(self):
+ def test_h1_as_title(self) -> None:
html = b"""
<html>
<meta property="og:description" content="Some text."/>
@@ -247,11 +246,11 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
- def test_missing_title_and_broken_h1(self):
+ def test_missing_title_and_broken_h1(self) -> None:
html = b"""
<html>
<body>
@@ -262,23 +261,23 @@ class CalcOgTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
- def test_empty(self):
+ def test_empty(self) -> None:
"""Test a body with no data in it."""
html = b""
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
- def test_no_tree(self):
+ def test_no_tree(self) -> None:
"""A valid body with no tree in it."""
html = b"\x00"
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
- def test_xml(self):
+ def test_xml(self) -> None:
"""Test decoding XML and ensure it works properly."""
# Note that the strip() call is important to ensure the xml tag starts
# at the initial byte.
@@ -290,10 +289,10 @@ class CalcOgTestCase(unittest.TestCase):
<head><title>Foo</title></head><body>Some text.</body></html>
""".strip()
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_invalid_encoding(self):
+ def test_invalid_encoding(self) -> None:
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
html = b"""
<html>
@@ -304,10 +303,10 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
- def test_invalid_encoding2(self):
+ def test_invalid_encoding2(self) -> None:
"""A body which doesn't match the sent character encoding."""
# Note that this contains an invalid UTF-8 sequence in the title.
html = b"""
@@ -319,10 +318,10 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
- def test_windows_1252(self):
+ def test_windows_1252(self) -> None:
"""A body which uses cp1252, but doesn't declare that."""
html = b"""
<html>
@@ -333,12 +332,12 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+ og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
class MediaEncodingTestCase(unittest.TestCase):
- def test_meta_charset(self):
+ def test_meta_charset(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -363,7 +362,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_meta_charset_underscores(self):
+ def test_meta_charset_underscores(self) -> None:
"""A character encoding contains underscore."""
encodings = _get_html_media_encodings(
b"""
@@ -376,7 +375,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
- def test_xml_encoding(self):
+ def test_xml_encoding(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -388,7 +387,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_meta_xml_encoding(self):
+ def test_meta_xml_encoding(self) -> None:
"""Meta tags take precedence over XML encoding."""
encodings = _get_html_media_encodings(
b"""
@@ -402,7 +401,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
- def test_content_type(self):
+ def test_content_type(self) -> None:
"""A character encoding is found via the Content-Type header."""
# Test a few variations of the header.
headers = (
@@ -417,12 +416,12 @@ class MediaEncodingTestCase(unittest.TestCase):
encodings = _get_html_media_encodings(b"", header)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
- def test_fallback(self):
+ def test_fallback(self) -> None:
"""A character encoding cannot be found in the body or header."""
encodings = _get_html_media_encodings(b"", "text/html")
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
- def test_duplicates(self):
+ def test_duplicates(self) -> None:
"""Ensure each encoding is only attempted once."""
encodings = _get_html_media_encodings(
b"""
@@ -436,7 +435,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
- def test_unknown_invalid(self):
+ def test_unknown_invalid(self) -> None:
"""A character encoding should be ignored if it is unknown or invalid."""
encodings = _get_html_media_encodings(
b"""
@@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
'text/html; charset="invalid"',
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
-
-class RebaseUrlTestCase(unittest.TestCase):
- def test_relative(self):
- """Relative URLs should be resolved based on the context of the base URL."""
- self.assertEqual(
- rebase_url("subpage", "https://example.com/foo/"),
- "https://example.com/foo/subpage",
- )
- self.assertEqual(
- rebase_url("sibling", "https://example.com/foo"),
- "https://example.com/sibling",
- )
- self.assertEqual(
- rebase_url("/bar", "https://example.com/foo/"),
- "https://example.com/bar",
- )
-
- def test_absolute(self):
- """Absolute URLs should not be modified."""
- self.assertEqual(
- rebase_url("https://alice.com/a/", "https://example.com/foo/"),
- "https://alice.com/a/",
- )
-
- def test_data(self):
- """Data URLs should not be modified."""
- self.assertEqual(
- rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
- "data:,Hello%2C%20World%21",
- )
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index cba9be17c4..7204b2dfe0 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Optional
+from typing import Any, BinaryIO, Dict, List, Optional, Union
from unittest.mock import Mock
from urllib import parse
@@ -26,18 +26,24 @@ from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.server import HomeServer
+from synapse.types import RoomAlias
+from synapse.util import Clock
from tests import unittest
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import SMALL_PNG
from tests.utils import default_config
@@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
needs_threadpool = True
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir)
@@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
hs, self.primary_base_path, self.filepaths, storage_providers
)
- def test_ensure_media_is_in_local_cache(self):
+ def test_ensure_media_is_in_local_cache(self) -> None:
media_id = "some_media_id"
test_body = "Test\n"
@@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
@@ -121,18 +127,18 @@ class _TestImage:
a 404 is expected.
"""
- data = attr.ib(type=bytes)
- content_type = attr.ib(type=bytes)
- extension = attr.ib(type=bytes)
- expected_cropped = attr.ib(type=Optional[bytes], default=None)
- expected_scaled = attr.ib(type=Optional[bytes], default=None)
- expected_found = attr.ib(default=True, type=bool)
+ data: bytes
+ content_type: bytes
+ extension: bytes
+ expected_cropped: Optional[bytes] = None
+ expected_scaled: Optional[bytes] = None
+ expected_found: bool = True
@parameterized_class(
("test_image",),
[
- # smoll png
+ # small png
(
_TestImage(
SMALL_PNG,
@@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True
user_id = "@test:user"
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches = []
- def get_file(destination, path, output_stream, args=None, max_size=None):
+ def get_file(
+ destination: str,
+ path: str,
+ output_stream: BinaryIO,
+ args: Optional[Dict[str, Union[str, List[str]]]] = None,
+ max_size: Optional[int] = None,
+ ) -> Deferred:
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
@@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"]
@@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.media_id = "example.com/12345"
- def _req(self, content_disposition, include_content_type=True):
-
+ def _req(
+ self, content_disposition: Optional[bytes], include_content_type: bool = True
+ ) -> FakeChannel:
channel = make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
@@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return channel
- def test_handle_missing_content_type(self):
+ def test_handle_missing_content_type(self) -> None:
channel = self._req(
b"inline; filename=out" + self.test_image.extension,
include_content_type=False,
@@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
)
- def test_disposition_filename_ascii(self):
+ def test_disposition_filename_ascii(self) -> None:
"""
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
@@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename=out" + self.test_image.extension],
)
- def test_disposition_filenamestar_utf8escaped(self):
+ def test_disposition_filenamestar_utf8escaped(self) -> None:
"""
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
correctly decode it as the UTF-8 string, and use filename* in the
@@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename*=utf-8''" + filename + self.test_image.extension],
)
- def test_disposition_none(self):
+ def test_disposition_none(self) -> None:
"""
If there is no filename, one isn't passed on in the Content-Disposition
of the request.
@@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
- def test_thumbnail_crop(self):
+ def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
- def test_thumbnail_scale(self):
+ def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
- def test_invalid_type(self):
+ def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
- def test_no_thumbnail_crop(self):
+ def test_no_thumbnail_crop(self) -> None:
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
@@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
- def test_no_thumbnail_scale(self):
+ def test_no_thumbnail_scale(self) -> None:
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail("scale", None, False)
- def test_thumbnail_repeated_thumbnail(self):
+ def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
@@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.result["body"],
)
- def _test_thumbnail(self, method, expected_body, expected_found):
+ def _test_thumbnail(
+ self, method: str, expected_body: Optional[bytes], expected_found: bool
+ ) -> None:
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
@@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
- def test_same_quality(self, method, desired_size):
+ def test_same_quality(self, method: str, desired_size: int) -> None:
"""Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen."""
@@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
)
- def test_x_robots_tag_header(self):
+ def test_x_robots_tag_header(self) -> None:
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
to not index, archive, or follow links in media.
@@ -540,29 +555,38 @@ class TestSpamChecker:
`evil`.
"""
- def __init__(self, config, api):
+ def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
self.config = config
self.api = api
- def parse_config(config):
+ def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
- async def check_event_for_spam(self, foo):
+ async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
return False # allow all events
- async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ async def user_may_invite(
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ room_id: str,
+ ) -> bool:
return True # allow all invites
- async def user_may_create_room(self, userid):
+ async def user_may_create_room(self, userid: str) -> bool:
return True # allow all room creations
- async def user_may_create_room_alias(self, userid, room_alias):
+ async def user_may_create_room_alias(
+ self, userid: str, room_alias: RoomAlias
+ ) -> bool:
return True # allow all room aliases
- async def user_may_publish_room(self, userid, room_id):
+ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
return True # allow publishing of all rooms
- async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+ async def check_media_file_for_spam(
+ self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+ ) -> bool:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
@@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
@@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
load_legacy_spam_checkers(hs)
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = default_config("test")
config.update(
@@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
return config
- def test_upload_innocent(self):
+ def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
- def test_upload_ban(self):
+ def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker.
"""
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index 048d0ca44a..f38d7225f8 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -16,7 +16,7 @@ import json
from twisted.test.proto_helpers import MemoryReactor
-from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
class OEmbedTests(HomeserverTestCase):
- def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
- self.oembed = OEmbedProvider(homeserver)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.oembed = OEmbedProvider(hs)
- def parse_response(self, response: JsonDict):
+ def parse_response(self, response: JsonDict) -> OEmbedResult:
return self.oembed.parse_oembed_response(
"https://test", json.dumps(response).encode("utf-8")
)
- def test_version(self):
+ def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1):
result = self.parse_response({"version": version, "type": "link"})
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index da2c533260..5148c39874 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -16,16 +16,21 @@ import base64
import json
import os
import re
+from typing import Any, Dict, Optional, Sequence, Tuple, Type
from urllib.parse import urlencode
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
@@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["url_preview_enabled"] = True
@@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"]
- self.lookups = {}
+ self.lookups: Dict[str, Any] = {}
class Resolver:
def resolveHostName(
_self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics="TCP",
- ):
+ resolutionReceiver: IResolutionReceiver,
+ hostName: str,
+ portNumber: int = 0,
+ addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+ transportSemantics: str = "TCP",
+ ) -> IResolutionReceiver:
resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution)
@@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
resolutionReceiver.resolutionComplete()
return resolutionReceiver
- self.reactor.nameResolver = Resolver()
+ self.reactor.nameResolver = Resolver() # type: ignore[assignment]
- def create_test_resource(self):
+ def create_test_resource(self) -> MediaRepositoryResource:
return self.hs.get_media_repository_resource()
def _assert_small_png(self, json_body: JsonDict) -> None:
@@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(json_body["og:image:type"], "image/png")
self.assertEqual(json_body["matrix:image:size"], 67)
- def test_cache_returns_correct_type(self):
+ def test_cache_returns_correct_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
channel = self.make_request(
@@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_non_ascii_preview_httpequiv(self):
+ def test_non_ascii_preview_httpequiv(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
- def test_video_rejected(self):
+ def test_video_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_audio_rejected(self):
+ def test_audio_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_non_ascii_preview_content_type(self):
+ def test_non_ascii_preview_content_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
- def test_overlong_title(self):
+ def test_overlong_title(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# We should only see the `og:description` field, as `title` is too long and should be stripped out
self.assertCountEqual(["og:description"], res.keys())
- def test_ipaddr(self):
+ def test_ipaddr(self) -> None:
"""
IP addresses can be previewed directly.
"""
@@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_blacklisted_ip_specific(self):
+ def test_blacklisted_ip_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_range(self):
+ def test_blacklisted_ip_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_specific_direct(self):
+ def test_blacklisted_ip_specific_direct(self) -> None:
"""
Blacklisted IP addresses, accessed directly, are not spidered.
"""
@@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 403)
- def test_blacklisted_ip_range_direct(self):
+ def test_blacklisted_ip_range_direct(self) -> None:
"""
Blacklisted IP ranges, accessed directly, are not spidered.
"""
@@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ip_range_whitelisted_ip(self):
+ def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
"""
Blacklisted but then subsequently whitelisted IP addresses can be
spidered.
@@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
- def test_blacklisted_ip_with_external_ip(self):
+ def test_blacklisted_ip_with_external_ip(self) -> None:
"""
If a hostname resolves a blacklisted IP, even if there's a
non-blacklisted one, it will be rejected.
@@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ipv6_specific(self):
+ def test_blacklisted_ipv6_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_blacklisted_ipv6_range(self):
+ def test_blacklisted_ipv6_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_OPTIONS(self):
+ def test_OPTIONS(self) -> None:
"""
OPTIONS returns the OPTIONS.
"""
@@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
- def test_accept_language_config_option(self):
+ def test_accept_language_config_option(self) -> None:
"""
Accept-Language header is sent to the remote server
"""
@@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
- def test_data_url(self):
+ def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
"""
@@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 500)
- def test_inline_data_url(self):
+ def test_inline_data_url(self) -> None:
"""
An inline image (as a data URL) should be parsed properly.
"""
@@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self._assert_small_png(channel.json_body)
- def test_oembed_photo(self):
+ def test_oembed_photo(self) -> None:
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
self._assert_small_png(body)
- def test_oembed_rich(self):
+ def test_oembed_rich(self) -> None:
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_oembed_format(self):
+ def test_oembed_format(self) -> None:
"""Test an oEmbed endpoint which requires the format in the URL."""
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
@@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
- def test_oembed_autodiscovery(self):
+ def test_oembed_autodiscovery(self) -> None:
"""
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
1. Request a preview of a URL which is not known to the oEmbed code.
@@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self._assert_small_png(body)
- def _download_image(self):
+ def _download_image(self) -> Tuple[str, str]:
"""Downloads an image into the URL cache.
Returns:
A (host, media_id) tuple representing the MXC URI of the image.
@@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIsNone(_port)
return host, media_id
- def test_storage_providers_exclude_files(self):
+ def test_storage_providers_exclude_files(self) -> None:
"""Test that files are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache file was unexpectedly retrieved from a storage provider",
)
- def test_storage_providers_exclude_thumbnails(self):
+ def test_storage_providers_exclude_thumbnails(self) -> None:
"""Test that thumbnails are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
- def test_cache_expiry(self):
+ def test_cache_expiry(self) -> None:
"""Test that URL cache files and thumbnails are cleaned up properly on expiry."""
self.preview_url.clock = MockClock()
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 01d48c3860..da325955f8 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from http import HTTPStatus
from synapse.rest.health import HealthResource
@@ -19,12 +19,12 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase):
- def create_test_resource(self):
+ def create_test_resource(self) -> HealthResource:
# replace the JsonResource with a HealthResource.
return HealthResource()
- def test_health(self):
+ def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 118aa93a32..11f78f52b8 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -19,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
- def create_test_resource(self):
+ def create_test_resource(self) -> Resource:
# replace the JsonResource with a Resource wrapping the WellKnownResource
res = Resource()
res.putChild(b".well-known", well_known_resource(self.hs))
@@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
"default_identity_server": "https://testis",
}
)
- def test_client_well_known(self):
+ def test_client_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{
@@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
"public_baseurl": None,
}
)
- def test_client_well_known_no_public_baseurl(self):
+ def test_client_well_known_no_public_baseurl(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@unittest.override_config({"serve_server_wellknown": True})
- def test_server_well_known(self):
+ def test_server_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{"m.server": "test:443"},
)
- def test_server_well_known_disabled(self):
+ def test_server_well_known_disabled(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, 404)
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|