diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[users[2]],
)
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_no_account_validity(self) -> None:
+ """Tests that if we decide to include account validity in the response but no
+ account validity 'is_user_expired' callback is provided, we default to marking all
+ users as not expired.
+ """
+ user = self.register_user("someuser", "password")
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": False,
+ },
+ },
+ expected_failures=[],
+ )
+
+ @unittest.override_config(
+ {
+ "use_account_validity_in_account_status": True,
+ }
+ )
+ def test_account_validity_expired(self) -> None:
+ """Test that if we decide to include account validity in the response and the user
+ is expired, we return the correct info.
+ """
+ user = self.register_user("someuser", "password")
+
+ async def is_expired(user_id: str) -> bool:
+ # We can't blindly say everyone is expired, otherwise the request to get the
+ # account status will fail.
+ return UserID.from_string(user_id).localpart == "someuser"
+
+ self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+ is_expired
+ )
+
+ self._test_status(
+ users=[user],
+ expected_statuses={
+ user: {
+ "exists": True,
+ "deactivated": False,
+ "org.matrix.expired": True,
+ },
+ },
+ expected_failures=[],
+ )
+
def _test_status(
self,
users: Optional[List[str]],
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 3818b7b14b..7b7d283bb6 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -14,7 +14,7 @@
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.rest.client import login, room, shared_rooms
+from synapse.rest.client import login, mutual_rooms, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -22,16 +22,16 @@ from tests import unittest
from tests.server import FakeChannel
-class UserSharedRoomsTest(unittest.HomeserverTestCase):
+class UserMutualRoomsTest(unittest.HomeserverTestCase):
"""
- Tests the UserSharedRoomsServlet.
+ Tests the UserMutualRoomsServlet.
"""
servlets = [
login.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
- shared_rooms.register_servlets,
+ mutual_rooms.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_user_directory_handler()
- def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
+ def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
"GET",
- "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
% other_user,
access_token=token,
)
@@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
A room should show up in the shared list of rooms between two users
if it is public.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
def test_shared_room_list_private(self) -> None:
"""
A room should show up in the shared list of rooms between two users
if it is private.
"""
- self._check_shared_rooms_with(
+ self._check_mutual_rooms_with(
room_one_is_public=False, room_two_is_public=False
)
@@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
The shared room list between two users should contain both public and private
rooms.
"""
- self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+ self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
- def _check_shared_rooms_with(
+ def _check_mutual_rooms_with(
self, room_one_is_public: bool, room_two_is_public: bool
) -> None:
"""Checks that shared public or private rooms between two users appear in
@@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
# Check shared rooms from user1's perspective.
# We should see the one room in common
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room_id_one)
@@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room_id_two, user=u2, tok=u2_token)
# Check shared rooms again. We should now see both rooms.
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 2)
for room_id_id in channel.json_body["joined"]:
@@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.join(room, user=u2, tok=u2_token)
# Assert user directory is not empty
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 1)
self.assertEqual(channel.json_body["joined"][0], room)
@@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
self.helper.leave(room, user=u1, tok=u1_token)
# Check user1's view of shared rooms with user2
- channel = self._get_shared_rooms(u1_token, u2)
+ channel = self._get_mutual_rooms(u1_token, u2)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
# Check user2's view of shared rooms with user1
- channel = self._get_shared_rooms(u2_token, u1)
+ channel = self._get_mutual_rooms(u2_token, u1)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 709f851a38..fe97a0b3dd 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,17 +15,16 @@
import itertools
import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
from synapse.server import HomeServer
-from synapse.storage.relations import RelationPaginationToken
-from synapse.types import JsonDict, StreamToken
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -80,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content: Optional[dict] = None,
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
+ expected_response_code: int = 200,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
@@ -116,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
content,
access_token=access_token,
)
+ self.assertEqual(expected_response_code, channel.code, channel.json_body)
return channel
+ def _get_related_events(self) -> List[str]:
+ """
+ Requests /relations on the parent ID and returns a list of event IDs.
+ """
+ # Request the relations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+ def _get_bundled_aggregations(self) -> JsonDict:
+ """
+ Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+ """
+ # Fetch the bundled aggregations of the event.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return channel.json_body["unsigned"].get("m.relations", {})
+
+ def _get_aggregations(self) -> List[JsonDict]:
+ """Request /aggregations on the parent ID and includes the returned chunk."""
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ return channel.json_body["chunk"]
+
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find and return the parent event in a chunk of events, raising if it is not found.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
"""Tests that sending a relation works."""
-
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
-
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -152,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_deny_invalid_event(self) -> None:
"""Test that we deny relations on non-existant events"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.ANNOTATION,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
# Unless that event is referenced from another event!
self.get_success(
@@ -172,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase):
desc="test_deny_invalid_event",
)
)
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
parent_id="foo",
content={"body": "foo", "msgtype": "m.text"},
)
- self.assertEqual(200, channel.code, channel.json_body)
def test_deny_invalid_room(self) -> None:
"""Test that we deny relations on non-existant events"""
@@ -188,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase):
parent_id = res["event_id"]
# Attempt to send an annotation to that event.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ parent_id=parent_id,
+ key="A",
+ expected_response_code=400,
)
- self.assertEqual(400, channel.code, channel.json_body)
def test_deny_double_react(self) -> None:
"""Test that we deny relations on membership events"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(400, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+ )
def test_deny_forked_thread(self) -> None:
"""It is invalid to start a thread off a thread."""
@@ -209,386 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo"},
parent_id=self.parent_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
parent_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=parent_id,
- )
- self.assertEqual(400, channel.code, channel.json_body)
-
- def test_basic_paginate_relations(self) -> None:
- """Tests that calling pagination API correctly the latest relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- first_annotation_id = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
- second_annotation_id = channel.json_body["event_id"]
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the latest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": second_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
- # Make sure next_batch has something in it that looks like it could be a
- # valid token.
- self.assertIsInstance(
- channel.json_body.get("next_batch"), str, channel.json_body
- )
-
- # Request the relations again, but with a different direction.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect to get back a single pagination result, which is the earliest
- # full relation event we sent above.
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
- self.assert_dict(
- {
- "event_id": first_annotation_id,
- "sender": self.user_id,
- "type": "m.reaction",
- },
- channel.json_body["chunk"][0],
- )
-
- def _stream_token_to_relation_token(self, token: str) -> str:
- """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
- room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
- return self.get_success(
- RelationPaginationToken(
- topological=room_key.topological, stream=room_key.stream
- ).to_string(self.store)
- )
-
- def test_repeated_paginate_relations(self) -> None:
- """Test that if we paginate using a limit and tokens then we get the
- expected events.
- """
-
- expected_event_ids = []
- for idx in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- prev_token = ""
- found_event_ids: List[str] = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- def test_pagination_from_sync_and_messages(self) -> None:
- """Pagination tokens from /sync and /messages can be used to paginate /relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
- # Send an event after the relation events.
- self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
- # Request /sync, limiting it such that only the latest event is returned
- # (and not the relation).
- filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
- channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- sync_prev_batch = room_timeline["prev_batch"]
- self.assertIsNotNone(sync_prev_batch)
- # Ensure the relation event is not in the batch returned from /sync.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
- )
-
- # Request /messages, limiting it such that only the latest event is
- # returned (and not the relation).
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b&limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- messages_end = channel.json_body["end"]
- self.assertIsNotNone(messages_end)
- # Ensure the relation event is not in the chunk returned from /messages.
- self.assertNotIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ expected_response_code=400,
)
- # Request /relations with the pagination tokens received from both the
- # /sync and /messages responses above, in turn.
- #
- # This is a tiny bit silly since the client wouldn't know the parent ID
- # from the requests above; consider the parent ID to be known from a
- # previous /sync.
- for from_token in (sync_prev_batch, messages_end):
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # The relation should be in the returned chunk.
- self.assertIn(
- annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
- )
-
- def test_aggregation_pagination_groups(self) -> None:
- """Test that we can paginate annotation groups correctly."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
- for key in itertools.chain.from_iterable(
- itertools.repeat(key, num) for key, num in sent_groups.items()
- ):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key=key,
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- idx += 1
- idx %= len(access_tokens)
-
- prev_token: Optional[str] = None
- found_groups: Dict[str, int] = {}
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- for groups in channel.json_body["chunk"]:
- # We only expect reactions
- self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
- # We should only see each key once
- self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
- found_groups[groups["key"]] = groups["count"]
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- self.assertEqual(sent_groups, found_groups)
-
- def test_aggregation_pagination_within_group(self) -> None:
- """Test that we can paginate within an annotation group."""
-
- # We need to create ten separate users to send each reaction.
- access_tokens = [self.user_token, self.user2_token]
- idx = 0
- while len(access_tokens) < 10:
- user_id, token = self._create_user("test" + str(idx))
- idx += 1
-
- self.helper.join(self.room, user=user_id, tok=token)
- access_tokens.append(token)
-
- idx = 0
- expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="👍",
- access_token=access_tokens[idx],
- )
- self.assertEqual(200, channel.code, channel.json_body)
- expected_event_ids.append(channel.json_body["event_id"])
-
- idx += 1
-
- # Also send a different type of reaction so that we test we don't see it
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- prev_token = ""
- found_event_ids: List[str] = []
- encoded_key = urllib.parse.quote_plus("👍".encode())
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + prev_token
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
- # Reset and try again, but convert the tokens to the legacy format.
- prev_token = ""
- found_event_ids = []
- for _ in range(20):
- from_token = ""
- if prev_token:
- from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}"
- f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
- f"/m.reaction/{encoded_key}?limit=1{from_token}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
- found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
- next_batch = channel.json_body.get("next_batch")
-
- self.assertNotEqual(prev_token, next_batch)
- prev_token = next_batch
-
- if not prev_token:
- break
-
- # We paginated backwards, so reverse
- found_event_ids.reverse()
- self.assertEqual(found_event_ids, expected_event_ids)
-
def test_aggregation(self) -> None:
"""Test that annotations get correctly aggregated."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
channel = self.make_request(
"GET",
@@ -618,220 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
- @unittest.override_config(
- {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
- )
- def test_bundled_aggregations(self) -> None:
- """
- Test that annotations, references, and threads get correctly bundled.
-
- Note that this doesn't test against /relations since only thread relations
- get bundled via that API. See test_aggregation_get_event_for_thread.
-
- See test_edit for a similar test for edits.
- """
- # Setup by sending a variety of relations.
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_1 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- reply_2 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_2 = channel.json_body["event_id"]
-
- def assert_bundle(event_json: JsonDict) -> None:
- """Assert the expected values of the bundled aggregations."""
- relations_dict = event_json["unsigned"].get("m.relations")
-
- # Ensure the fields are as expected.
- self.assertCountEqual(
- relations_dict.keys(),
- (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.THREAD,
- ),
- )
-
- # Check the values of each field.
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- relations_dict[RelationTypes.ANNOTATION],
- )
-
- self.assertEqual(
- {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- relations_dict[RelationTypes.REFERENCE],
- )
-
- self.assertEqual(
- 2,
- relations_dict[RelationTypes.THREAD].get("count"),
- )
- self.assertTrue(
- relations_dict[RelationTypes.THREAD].get("current_user_participated")
- )
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
- },
- "event_id": thread_2,
- "room_id": self.room,
- "sender": self.user_id,
- "type": "m.room.test",
- "user_id": self.user_id,
- },
- relations_dict[RelationTypes.THREAD].get("latest_event"),
- )
-
- # Request the event directly.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body)
-
- # Request the room messages.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
- # Request the room context.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"])
-
- # Request sync.
- channel = self.make_request("GET", "/sync", access_token=self.user_token)
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
- # Request search.
- channel = self.make_request(
- "POST",
- "/search",
- # Search term matches the parent message.
- content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- chunk = [
- result["result"]
- for result in channel.json_body["search_categories"]["room_events"][
- "results"
- ]
- ]
- assert_bundle(self._find_event_in_chunk(chunk))
-
- def test_aggregation_get_event_for_annotation(self) -> None:
- """Test that annotations do not get bundled aggregations included
- when directly requested.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
- annotation_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{annotation_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
- def test_aggregation_get_event_for_thread(self) -> None:
- """Test that threads get bundled aggregations included when directly requested."""
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
- thread_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{thread_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- # It should also be included when the entire thread is requested.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- thread_message = channel.json_body["chunk"][0]
- self.assertEqual(
- thread_message["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -953,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
def assert_bundle(event_json: JsonDict) -> None:
@@ -1030,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase):
shouldn't be allowed, are correctly handled.
"""
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -1039,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
@@ -1047,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message.WRONG_TYPE",
content={
@@ -1060,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
@@ -1091,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
reply = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1101,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase):
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=reply,
)
- self.assertEqual(200, channel.code, channel.json_body)
-
edit_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1138,7 +598,6 @@ class RelationsTestCase(BaseRelationsTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_edit_thread(self) -> None:
"""Test that editing a thread works."""
@@ -1148,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.room.message",
content={"msgtype": "m.text", "body": "A threaded reply!"},
)
- self.assertEqual(200, channel.code, channel.json_body)
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=threaded_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Fetch the thread root, to get the bundled aggregation for the thread.
channel = self.make_request(
@@ -1190,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase):
"m.new_content": new_body,
},
)
- self.assertEqual(200, channel.code, channel.json_body)
edit_event_id = channel.json_body["event_id"]
# Edit the edit event.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -1204,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
},
parent_id=edit_event_id,
)
- self.assertEqual(200, channel.code, channel.json_body)
# Request the original event.
channel = self.make_request(
@@ -1231,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1272,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- return event
-
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_good = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
- self.assertEqual(200, channel.code, channel.json_body)
annotation_event_id_bad = channel.json_body["event_id"]
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- self.assertEqual(200, channel.code, channel.json_body)
thread_event_id = channel.json_body["event_id"]
# Clean-up the table as if the inserts did not happen during event creation.
@@ -1345,8 +786,638 @@ class RelationsTestCase(BaseRelationsTestCase):
)
+class RelationPaginationTestCase(BaseRelationsTestCase):
+ def test_basic_paginate_relations(self) -> None:
+ """Tests that calling pagination API correctly the latest relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ first_annotation_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ second_annotation_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the latest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": second_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEqual(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), str, channel.json_body
+ )
+
+ # Request the relations again, but with a different direction.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations"
+ f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the earliest
+ # full relation event we sent above.
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {
+ "event_id": first_annotation_id,
+ "sender": self.user_id,
+ "type": "m.reaction",
+ },
+ channel.json_body["chunk"][0],
+ )
+
+ def test_repeated_paginate_relations(self) -> None:
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
+ """
+
+ expected_event_ids = []
+ for idx in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+ def test_pagination_from_sync_and_messages(self) -> None:
+ """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+ annotation_id = channel.json_body["event_id"]
+ # Send an event after the relation events.
+ self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+ # Request /sync, limiting it such that only the latest event is returned
+ # (and not the relation).
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ sync_prev_batch = room_timeline["prev_batch"]
+ self.assertIsNotNone(sync_prev_batch)
+ # Ensure the relation event is not in the batch returned from /sync.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+ )
+
+ # Request /messages, limiting it such that only the latest event is
+ # returned (and not the relation).
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b&limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ messages_end = channel.json_body["end"]
+ self.assertIsNotNone(messages_end)
+ # Ensure the relation event is not in the chunk returned from /messages.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ # Request /relations with the pagination tokens received from both the
+ # /sync and /messages responses above, in turn.
+ #
+ # This is a tiny bit silly since the client wouldn't know the parent ID
+ # from the requests above; consider the parent ID to be known from a
+ # previous /sync.
+ for from_token in (sync_prev_batch, messages_end):
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # The relation should be in the returned chunk.
+ self.assertIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ def test_aggregation_pagination_groups(self) -> None:
+ """Test that we can paginate annotation groups correctly."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token: Optional[str] = None
+ found_groups: Dict[str, int] = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self) -> None:
+ """Test that we can paginate within an annotation group."""
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="👍",
+ access_token=access_tokens[idx],
+ )
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ idx += 1
+
+ # Also send a different type of reaction so that we test we don't see it
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+ prev_token = ""
+ found_event_ids: List[str] = []
+ encoded_key = urllib.parse.quote_plus("👍".encode())
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+ """
+ See RelationsTestCase.test_edit for a similar test for edits.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+ """
+
+ def _test_bundled_aggregations(
+ self,
+ relation_type: str,
+ assertion_callable: Callable[[JsonDict], None],
+ expected_db_txn_for_event: int,
+ ) -> None:
+ """
+ Makes requests to various endpoints which should include bundled aggregations
+ and then calls an assertion function on the bundled aggregations.
+
+ Args:
+ relation_type: The field to search for in the `m.relations` field in unsigned.
+ assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+ for relation-specific assertions.
+ expected_db_txn_for_event: The number of database transactions which
+ are expected for a call to /event/.
+ """
+
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+
+ # Ensure the fields are as expected.
+ self.assertCountEqual(relations_dict.keys(), (relation_type,))
+ assertion_callable(relations_dict[relation_type])
+
+ # Request the event directly.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
+
+ # Request sync.
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_annotation(self) -> None:
+ """
+ Test that annotations get correctly bundled.
+ """
+ # Setup by sending a variety of relations.
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_reference(self) -> None:
+ """
+ Test that references get correctly bundled.
+ """
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_thread(self) -> None:
+ """
+ Test that threads get correctly bundled.
+ """
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertTrue(bundled_aggregations.get("current_user_participated"))
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ },
+ bundled_aggregations.get("latest_event"),
+ )
+
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+ def test_aggregation_get_event_for_annotation(self) -> None:
+ """Test that annotations do not get bundled aggregations included
+ when directly requested.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ annotation_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{annotation_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+ def test_aggregation_get_event_for_thread(self) -> None:
+ """Test that threads get bundled aggregations included when directly requested."""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_id = channel.json_body["event_id"]
+
+ # Annotate the thread reply.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{thread_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ # It should also be included when the entire thread is requested.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+ thread_message = channel.json_body["chunk"][0]
+ self.assertEqual(
+ thread_message["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ def test_bundled_aggregations_with_filter(self) -> None:
+ """
+ If "unsigned" is an omitted field (due to filtering), adding the bundled
+ aggregations should not break.
+
+ Note that the spec allows for a server to return additional fields beyond
+ what is specified.
+ """
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+
+ # Note that the sync filter does not include "unsigned" as a field.
+ filter = urllib.parse.quote_plus(
+ b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # Ensure the timeline is limited, find the parent event.
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ parent_event = self._find_event_in_chunk(room_timeline["events"])
+
+ # Ensure there's bundled aggregations on it.
+ self.assertIn("unsigned", parent_event)
+ self.assertIn("m.relations", parent_event["unsigned"])
+
+
+class RelationIgnoredUserTestCase(BaseRelationsTestCase):
+ """Relations sent from an ignored user should be ignored."""
+
+ def _test_ignored_user(
+ self, allowed_event_ids: List[str], ignored_event_ids: List[str]
+ ) -> None:
+ """
+ Fetch the relations and ensure they're all there, then ignore user2, and
+ repeat.
+ """
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids)
+
+ # Ignore user2 and re-do the requests.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user2_id: {}}},
+ )
+ )
+
+ # Get the relations.
+ event_ids = self._get_related_events()
+ self.assertCountEqual(event_ids, allowed_event_ids)
+
+ def test_annotation(self) -> None:
+ """Annotations should ignore"""
+ # Send 2 from us, 2 from the to be ignored user.
+ allowed_event_ids = []
+ ignored_event_ids = []
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
+ allowed_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="a",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key="c",
+ access_token=self.user2_token,
+ )
+ ignored_event_ids.append(channel.json_body["event_id"])
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_reference(self) -> None:
+ """Annotations should ignore"""
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+ def test_thread(self) -> None:
+ """Thread replies from ignored users should be ignored."""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ allowed_event_ids = [channel.json_body["event_id"]]
+
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+ ignored_event_ids = [channel.json_body["event_id"]]
+
+ self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+
class RelationRedactionTestCase(BaseRelationsTestCase):
- """Test the behaviour of relations when the parent or child event is redacted."""
+ """
+ Test the behaviour of relations when the parent or child event is redacted.
+
+ The behaviour of each relation type is subtly different, which makes the tests
+ a bit repetitive; they follow a naming scheme of:
+
+ test_redact_(relation|parent)_{relation_type}
+
+ "relation" in the first part means that the event with the relation defined
+ on it (the child event) is to be redacted; "parent" means that the target of
+ the relation (the parent event) is to be redacted.
+
+ The relation_type describes which type of relation is under test (i.e. it is
+ related to the value of rel_type in the event content).
+ """
def _redact(self, event_id: str) -> None:
channel = self.make_request(
@@ -1358,40 +1429,116 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
def test_redact_relation_annotation(self) -> None:
- """Test that annotations of an event are properly handled after the
+ """
+ Test that annotations of an event are properly handled after the
annotation is redacted.
+
+ The redacted relation should not be included in bundled aggregations or
+ the response to relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
- self.assertEqual(200, channel.code, channel.json_body)
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Both relations should exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertEqual(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
+ )
+
+ # Both relations appear in the aggregation.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
# Redact one of the reactions.
self._redact(to_redact_event_id)
- # Ensure that the aggregations are correct.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
- access_token=self.user_token,
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(event_ids, [unredacted_event_id])
+ self.assertEqual(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
- self.assertEqual(200, channel.code, channel.json_body)
+ # The unredacted aggregation should still exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
+
+ def test_redact_relation_thread(self) -> None:
+ """
+ Test that thread replies are properly handled after the thread reply is redacted.
+
+ The redacted event should not be included in bundled aggregations or
+ the response to relations.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ unredacted_event_id = channel.json_body["event_id"]
+
+ # Note that the *last* event in the thread is redacted, as that gets
+ # included in the bundled aggregation.
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 2", "msgtype": "m.text"},
+ )
+ to_redact_event_id = channel.json_body["event_id"]
+
+ # Both relations exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 2,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event returned is the event that will be redacted.
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ to_redact_event_id,
+ )
+
+ # Redact one of the thread replies.
+ self._redact(to_redact_event_id)
+
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(event_ids, [unredacted_event_id])
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ # And the latest event is now the unredacted event.
self.assertEqual(
- channel.json_body,
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ unredacted_event_id,
)
- def test_redact_relation_edit(self) -> None:
+ def test_redact_parent_edit(self) -> None:
"""Test that edits of an event are redacted when the original event
is redacted.
"""
# Add a relation
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
parent_id=self.parent_id,
@@ -1401,54 +1548,83 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
- self.assertEqual(200, channel.code, channel.json_body)
# Check the relation is returned
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}/m.replace/m.room.message",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.REPLACE, relations)
# Redact the original event
self._redact(self.parent_id)
- # Try to check for remaining m.replace relations
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations"
- f"/{self.parent_id}/m.replace/m.room.message",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
+ # The relations are not returned.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 0)
+ self.assertEqual(relations, {})
- # Check that no relations are returned
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
-
- def test_redact_parent(self) -> None:
- """Test that annotations of an event are redacted when the original event
+ def test_redact_parent_annotation(self) -> None:
+ """Test that annotations of an event are viewable when the original event
is redacted.
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
- self.assertEqual(200, channel.code, channel.json_body)
+ related_event_id = channel.json_body["event_id"]
+
+ # The relations should exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 1)
+ self.assertIn(RelationTypes.ANNOTATION, relations)
+
+ # The aggregation should exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
# Redact the original event.
self._redact(self.parent_id)
- # Check that aggregations returns zero
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
- access_token=self.user_token,
+ # The relations are returned.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(event_ids, [related_event_id])
+ self.assertEqual(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIn("chunk", channel.json_body)
- self.assertEqual(channel.json_body["chunk"], [])
+ # The aggregation should still exist.
+ chunk = self._get_aggregations()
+ self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
+
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_redact_parent_thread(self) -> None:
+ """
+ Test that thread replies are still available when the root event is redacted.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ related_event_id = channel.json_body["event_id"]
+
+ # Redact the root of the thread.
+ self._redact(self.parent_id)
+
+ # The unredacted relation should still exist.
+ event_ids = self._get_related_events()
+ relations = self._get_bundled_aggregations()
+ self.assertEqual(len(event_ids), 1)
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ related_event_id,
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f3bf8d0934..7b8fe6d025 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -24,6 +24,7 @@ from synapse.util import Clock
from synapse.visibility import filter_events_for_client
from tests import unittest
+from tests.unittest import override_config
one_hour_ms = 3600000
one_day_ms = one_hour_ms * 24
@@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["retention"] = {
+
+ # merge this default retention config with anything that was specified in
+ # @override_config
+ retention_config = {
"enabled": True,
"default_policy": {
"min_lifetime": one_day_ms,
@@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase):
"allowed_lifetime_min": one_day_ms,
"allowed_lifetime_max": one_day_ms * 3,
}
+ retention_config.update(config.get("retention", {}))
+ config["retention"] = retention_config
self.hs = self.setup_test_homeserver(config=config)
@@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 2)
+ @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
def test_visibility(self) -> None:
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
- outdated events
+ outdated events, even if the purge job hasn't got to them yet.
+
+ We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
storage = self.hs.get_storage()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
- events = []
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
-
- # Get the event from the store so that we end up with a FrozenEvent that we can
- # give to filter_events_for_client. We need to do this now because the event won't
- # be in the database anymore after it has expired.
- events.append(self.get_success(store.get_event(resp.get("event_id"))))
+ first_event_id = resp.get("event_id")
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
-
valid_event_id = resp.get("event_id")
- events.append(self.get_success(store.get_event(valid_event_id)))
-
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
- # Run filter_events_for_client with our list of FrozenEvents.
+ # Fetch the events, and run filter_events_for_client on them
+ events = self.get_success(
+ store.get_events_as_list([first_event_id, valid_event_id])
+ )
+ self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 37866ee330..3a9617d6da 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
- filter = {"io.element.relation_senders": [self.second_user_id]}
+ filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which third user reacted to.
- filter = {"io.element.relation_senders": [self.third_user_id]}
+ filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which either user reacted to.
- filter = {
- "io.element.relation_senders": [self.second_user_id, self.third_user_id]
- }
+ filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
@@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_type(self) -> None:
# Messages which have annotations.
- filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which have references.
- filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which have either annotations or references.
filter = {
- "io.element.relation_types": [
+ "related_by_rel_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
@@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
- "io.element.relation_senders": [self.second_user_id],
- "io.element.relation_types": [RelationTypes.ANNOTATION],
+ "related_by_senders": [self.second_user_id],
+ "related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 58f1ea11b7..e7de67e3a3 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(args[0], user_id)
self.assertFalse(args[1])
self.assertTrue(args[2])
+
+ def test_check_can_deactivate_user(self) -> None:
+ """Tests that the check_can_deactivate_user module callback is called
+ correctly when processing a user's deactivation.
+ """
+ # Register a mocked callback.
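+ # It returns False, which should block the deactivation.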
+ deactivation_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_deactivate_user_callbacks.append(
+ deactivation_mock,
+ )
+
+ # Register a user that we'll deactivate.
+ user_id = self.register_user("altan", "password")
+ tok = self.login("altan", "password")
+
+ # Deactivate that user.
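+ # Self-deactivation requires user-interactive auth, hence the password auth dict.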
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/account/deactivate",
+ {
+ "auth": {
+ "type": LoginType.PASSWORD,
+ "password": "password",
+ "identifier": {
+ "type": "m.id.user",
+ "user": user_id,
+ },
+ },
+ "erase": True,
+ },
+ access_token=tok,
+ )
+
+ # Check that the deactivation was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ deactivation_mock.assert_called_once()
+ args = deactivation_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], user_id)
+
+ # Check that the request was not made by an admin
+ self.assertEqual(args[1], False)
+
+ def test_check_can_deactivate_user_admin(self) -> None:
+ """Tests that the check_can_deactivate_user module callback is called
+ correctly when processing a user's deactivation triggered by a server admin.
+ """
+ # Register a mocked callback.
+ deactivation_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_deactivate_user_callbacks.append(
+ deactivation_mock,
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Register a user that we'll deactivate.
+ user_id = self.register_user("altan", "password")
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {"deactivated": True},
+ access_token=admin_tok,
+ )
+
+ # Check that the deactivation was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ deactivation_mock.assert_called_once()
+ args = deactivation_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], user_id)
+
+ # Check that the request was made by an admin
+ self.assertEqual(args[1], True)
+
+ def test_check_can_shutdown_room(self) -> None:
+ """Tests that the check_can_shutdown_room module callback is called
+ correctly when processing an admin's shutdown room request.
+ """
+ # Register a mocked callback.
+ shutdown_mock = Mock(return_value=make_awaitable(False))
+ third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules._check_can_shutdown_room_callbacks.append(
+ shutdown_mock,
+ )
+
+ # Register an admin user.
+ admin_user_id = self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Shut down the room.
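+ # The admin "delete room" API shuts the room down as part of deleting it.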
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v2/rooms/%s" % self.room_id,
+ {},
+ access_token=admin_tok,
+ )
+
+ # Check that the shutdown was blocked
+ self.assertEqual(channel.code, 403, channel.json_body)
+
+ # Check that the mock was called once.
+ shutdown_mock.assert_called_once()
+ args = shutdown_mock.call_args[0]
+
+ # Check that the mock was called with the right user ID
+ self.assertEqual(args[0], admin_user_id)
+
+ # Check that the mock was called with the right room ID
+ self.assertEqual(args[1], self.room_id)
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3b5747cb12..8d8251b2ac 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,3 +1,18 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
from unittest.mock import Mock, call
from twisted.internet import defer, reactor
@@ -11,14 +26,14 @@ from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
- self.mock_http_response = (200, "GOOD JOB!")
+ self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_key = "foo"
@defer.inlineCallbacks