diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..78c2fb86b9 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -13,15 +13,15 @@
# limitations under the License.
import itertools
-import json
-import urllib
-from typing import Optional
+import urllib.parse
+from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room
from tests import unittest
+from tests.server import FakeChannel
class RelationsTestCase(unittest.HomeserverTestCase):
@@ -34,16 +34,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
hijack_auth = False
- def make_homeserver(self, reactor, clock):
+ def default_config(self) -> dict:
# We need to enable msc1849 support for aggregations
- config = self.default_config()
+ config = super().default_config()
config["experimental_msc1849_support_enabled"] = True
# We enable frozen dicts as relations/edits change event contents, so we
# want to test that we don't modify the events in the caches.
config["use_frozen_dicts"] = True
- return self.setup_test_homeserver(config=config)
+ return config
def prepare(self, reactor, clock, hs):
self.user_id, self.user_token = self._create_user("alice")
@@ -101,10 +101,10 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_basic_paginate_relations(self):
"""Tests that calling pagination API correctly the latest relations."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
self.assertEquals(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]
@@ -141,13 +141,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"""
expected_event_ids = []
- for _ in range(10):
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ for idx in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+ )
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
- prev_token = None
- found_event_ids = []
+ prev_token: Optional[str] = None
+ found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
if prev_token:
@@ -203,8 +205,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
idx += 1
idx %= len(access_tokens)
- prev_token = None
- found_groups = {}
+ prev_token: Optional[str] = None
+ found_groups: Dict[str, int] = {}
for _ in range(20):
from_token = ""
if prev_token:
@@ -270,8 +272,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
- prev_token = None
- found_event_ids = []
+ prev_token: Optional[str] = None
+ found_event_ids: List[str] = []
encoded_key = urllib.parse.quote_plus("👍".encode())
for _ in range(20):
from_token = ""
@@ -386,8 +388,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(400, channel.code, channel.json_body)
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_aggregation_get_event(self):
- """Test that annotations and references get correctly bundled when
+ """Test that annotations, references, and threads get correctly bundled when
getting the parent event.
"""
@@ -410,6 +413,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
reply_2 = channel.json_body["event_id"]
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_2 = channel.json_body["event_id"]
+
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
@@ -429,6 +439,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
RelationTypes.REFERENCE: {
"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
},
+ RelationTypes.THREAD: {
+ "count": 2,
+ "latest_event": {
+ "age": 100,
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "origin_server_ts": 1600,
+ "room_id": self.room,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ "unsigned": {"age": 100},
+ "user_id": self.user_id,
+ },
+ },
},
)
@@ -559,7 +588,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{
"m.relates_to": {
"event_id": self.parent_id,
- "key": None,
"rel_type": "m.reference",
}
},
@@ -677,24 +705,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def _send_relation(
self,
- relation_type,
- event_type,
- key=None,
+ relation_type: str,
+ event_type: str,
+ key: Optional[str] = None,
content: Optional[dict] = None,
- access_token=None,
- parent_id=None,
- ):
+ access_token: Optional[str] = None,
+ parent_id: Optional[str] = None,
+ ) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
Args:
- relation_type (str): One of `RelationTypes`
- event_type (str): The type of the event to create
- parent_id (str): The event_id this relation relates to. If None, then self.parent_id
- key (str|None): The aggregation key used for m.annotation relation
- type.
- content(dict|None): The content of the created event.
- access_token (str|None): The access token used to send the relation,
- defaults to `self.user_token`
+ relation_type: One of `RelationTypes`
+ event_type: The type of the event to create
+ key: The aggregation key used for m.annotation relation type.
+ content: The content of the created event.
+ access_token: The access token used to send the relation, defaults
+ to `self.user_token`
+ parent_id: The event_id this relation relates to. If None, then self.parent_id
Returns:
FakeChannel
@@ -712,12 +739,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
- json.dumps(content or {}).encode("utf-8"),
+ content or {},
access_token=access_token,
)
return channel
- def _create_user(self, localpart):
+ def _create_user(self, localpart: str) -> Tuple[str, str]:
user_id = self.register_user(localpart, "abc123")
access_token = self.login(localpart, "abc123")
|