summary refs log tree commit diff
path: root/tests/rest/client/test_relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_relations.py')
-rw-r--r--tests/rest/client/test_relations.py58
1 files changed, 31 insertions, 27 deletions
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 8f7181103b..c8db45719e 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -18,11 +18,15 @@ import urllib.parse
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import patch
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
+from synapse.server import HomeServer
 from synapse.storage.relations import RelationPaginationToken
 from synapse.types import JsonDict, StreamToken
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -52,7 +56,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
         self.user_id, self.user_token = self._create_user("alice")
@@ -63,7 +67,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
         self.parent_id = res["event_id"]
 
-    def test_send_relation(self):
+    def test_send_relation(self) -> None:
         """Tests that sending a relation using the new /send_relation works
         creates the right shape of event.
         """
@@ -95,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel.json_body,
         )
 
-    def test_deny_invalid_event(self):
+    def test_deny_invalid_event(self) -> None:
         """Test that we deny relations on non-existant events"""
         channel = self._send_relation(
             RelationTypes.ANNOTATION,
@@ -125,7 +129,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(200, channel.code, channel.json_body)
 
-    def test_deny_invalid_room(self):
+    def test_deny_invalid_room(self) -> None:
         """Test that we deny relations on non-existant events"""
         # Create another room and send a message in it.
         room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
@@ -138,7 +142,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(400, channel.code, channel.json_body)
 
-    def test_deny_double_react(self):
+    def test_deny_double_react(self) -> None:
         """Test that we deny relations on membership events"""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEqual(200, channel.code, channel.json_body)
@@ -146,7 +150,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEqual(400, channel.code, channel.json_body)
 
-    def test_deny_forked_thread(self):
+    def test_deny_forked_thread(self) -> None:
         """It is invalid to start a thread off a thread."""
         channel = self._send_relation(
             RelationTypes.THREAD,
@@ -165,7 +169,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(400, channel.code, channel.json_body)
 
-    def test_basic_paginate_relations(self):
+    def test_basic_paginate_relations(self) -> None:
         """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEqual(200, channel.code, channel.json_body)
@@ -235,7 +239,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             ).to_string(self.store)
         )
 
-    def test_repeated_paginate_relations(self):
+    def test_repeated_paginate_relations(self) -> None:
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
         """
@@ -303,7 +307,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         found_event_ids.reverse()
         self.assertEqual(found_event_ids, expected_event_ids)
 
-    def test_pagination_from_sync_and_messages(self):
+    def test_pagination_from_sync_and_messages(self) -> None:
         """Pagination tokens from /sync and /messages can be used to paginate /relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
         self.assertEqual(200, channel.code, channel.json_body)
@@ -362,7 +366,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
             )
 
-    def test_aggregation_pagination_groups(self):
+    def test_aggregation_pagination_groups(self) -> None:
         """Test that we can paginate annotation groups correctly."""
 
         # We need to create ten separate users to send each reaction.
@@ -427,7 +431,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sent_groups, found_groups)
 
-    def test_aggregation_pagination_within_group(self):
+    def test_aggregation_pagination_within_group(self) -> None:
         """Test that we can paginate within an annotation group."""
 
         # We need to create ten separate users to send each reaction.
@@ -524,7 +528,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         found_event_ids.reverse()
         self.assertEqual(found_event_ids, expected_event_ids)
 
-    def test_aggregation(self):
+    def test_aggregation(self) -> None:
         """Test that annotations get correctly aggregated."""
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -556,7 +560,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_aggregation_redactions(self):
+    def test_aggregation_redactions(self) -> None:
         """Test that annotations get correctly aggregated after a redaction."""
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -590,7 +594,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
         )
 
-    def test_aggregation_must_be_annotation(self):
+    def test_aggregation_must_be_annotation(self) -> None:
         """Test that aggregations must be annotations."""
 
         channel = self.make_request(
@@ -604,7 +608,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
     @unittest.override_config(
         {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
     )
-    def test_bundled_aggregations(self):
+    def test_bundled_aggregations(self) -> None:
         """
         Test that annotations, references, and threads get correctly bundled.
 
@@ -746,7 +750,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         ]
         assert_bundle(self._find_event_in_chunk(chunk))
 
-    def test_aggregation_get_event_for_annotation(self):
+    def test_aggregation_get_event_for_annotation(self) -> None:
         """Test that annotations do not get bundled aggregations included
         when directly requested.
         """
@@ -768,7 +772,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
         self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
 
-    def test_aggregation_get_event_for_thread(self):
+    def test_aggregation_get_event_for_thread(self) -> None:
         """Test that threads get bundled aggregations included when directly requested."""
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
         self.assertEqual(200, channel.code, channel.json_body)
@@ -815,7 +819,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
-    def test_ignore_invalid_room(self):
+    def test_ignore_invalid_room(self) -> None:
         """Test that we ignore invalid relations over federation."""
         # Create another room and send a message in it.
         room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
@@ -927,7 +931,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("m.relations", channel.json_body["unsigned"])
 
     @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
-    def test_edit(self):
+    def test_edit(self) -> None:
         """Test that a simple edit works."""
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1010,7 +1014,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         ]
         assert_bundle(self._find_event_in_chunk(chunk))
 
-    def test_multi_edit(self):
+    def test_multi_edit(self) -> None:
         """Test that multiple edits, including attempts by people who
         shouldn't be allowed, are correctly handled.
         """
@@ -1067,7 +1071,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
-    def test_edit_reply(self):
+    def test_edit_reply(self) -> None:
         """Test that editing a reply works."""
 
         # Create a reply to edit.
@@ -1124,7 +1128,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
-    def test_edit_thread(self):
+    def test_edit_thread(self) -> None:
         """Test that editing a thread works."""
 
         # Create a thread and edit the last event.
@@ -1163,7 +1167,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         latest_event_in_thread = thread_summary["latest_event"]
         self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
 
-    def test_edit_edit(self):
+    def test_edit_edit(self) -> None:
         """Test that an edit cannot be edited."""
         new_body = {"msgtype": "m.text", "body": "Initial edit"}
         channel = self._send_relation(
@@ -1213,7 +1217,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
-    def test_relations_redaction_redacts_edits(self):
+    def test_relations_redaction_redacts_edits(self) -> None:
         """Test that edits of an event are redacted when the original event
         is redacted.
         """
@@ -1269,7 +1273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertIn("chunk", channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
-    def test_aggregations_redaction_prevents_access_to_aggregations(self):
+    def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None:
         """Test that annotations of an event are redacted when the original event
         is redacted.
         """
@@ -1309,7 +1313,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertIn("chunk", channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
-    def test_unknown_relations(self):
+    def test_unknown_relations(self) -> None:
         """Unknown relations should be accepted."""
         channel = self._send_relation("m.relation.test", "m.room.test")
         self.assertEqual(200, channel.code, channel.json_body)
@@ -1417,7 +1421,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         return user_id, access_token
 
-    def test_background_update(self):
+    def test_background_update(self) -> None:
         """Test the event_arbitrary_relations background update."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
         self.assertEqual(200, channel.code, channel.json_body)