summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-10-14 09:19:35 -0400
committerGitHub <noreply@github.com>2021-10-14 09:19:35 -0400
commit1609ccf8fec87a941d3c27f668f6dca8f75a3f4a (patch)
treee3d4451d06b3270d0b33ae55886690c3ca653400
parentAdd a test for a workaround concerning the behaviour of third-party rule modu... (diff)
downloadsynapse-1609ccf8fec87a941d3c27f668f6dca8f75a3f4a.tar.xz
Fix-up some type hints in the relations tests. (#11076)
-rw-r--r--changelog.d/11076.misc1
-rw-r--r--mypy.ini1
-rw-r--r--tests/rest/client/test_relations.py55
-rw-r--r--tests/server.py54
-rw-r--r--tests/unittest.py4
5 files changed, 64 insertions, 51 deletions
diff --git a/changelog.d/11076.misc b/changelog.d/11076.misc
new file mode 100644
index 0000000000..c581a86e47
--- /dev/null
+++ b/changelog.d/11076.misc
@@ -0,0 +1 @@
+Fix type hints in the relations tests.
diff --git a/mypy.ini b/mypy.ini
index 2cdd552f46..cb4489eb37 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -92,6 +92,7 @@ files =
   tests/handlers/test_user_directory.py,
   tests/rest/client/test_login.py,
   tests/rest/client/test_auth.py,
+  tests/rest/client/test_relations.py,
   tests/rest/media/v1/test_filepath.py,
   tests/rest/media/v1/test_oembed.py,
   tests/storage/test_state.py,
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..3c7d49f0b4 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -13,15 +13,15 @@
 # limitations under the License.
 
 import itertools
-import json
-import urllib
-from typing import Optional
+import urllib.parse
+from typing import Dict, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room
 
 from tests import unittest
+from tests.server import FakeChannel
 
 
 class RelationsTestCase(unittest.HomeserverTestCase):
@@ -34,16 +34,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def make_homeserver(self, reactor, clock):
+    def default_config(self) -> dict:
         # We need to enable msc1849 support for aggregations
-        config = self.default_config()
+        config = super().default_config()
         config["experimental_msc1849_support_enabled"] = True
 
         # We enable frozen dicts as relations/edits change event contents, so we
         # want to test that we don't modify the events in the caches.
         config["use_frozen_dicts"] = True
 
-        return self.setup_test_homeserver(config=config)
+        return config
 
     def prepare(self, reactor, clock, hs):
         self.user_id, self.user_token = self._create_user("alice")
@@ -146,8 +146,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             self.assertEquals(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
-        prev_token = None
-        found_event_ids = []
+        prev_token: Optional[str] = None
+        found_event_ids: List[str] = []
         for _ in range(20):
             from_token = ""
             if prev_token:
@@ -203,8 +203,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             idx += 1
             idx %= len(access_tokens)
 
-        prev_token = None
-        found_groups = {}
+        prev_token: Optional[str] = None
+        found_groups: Dict[str, int] = {}
         for _ in range(20):
             from_token = ""
             if prev_token:
@@ -270,8 +270,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        prev_token = None
-        found_event_ids = []
+        prev_token: Optional[str] = None
+        found_event_ids: List[str] = []
         encoded_key = urllib.parse.quote_plus("👍".encode())
         for _ in range(20):
             from_token = ""
@@ -677,24 +677,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     def _send_relation(
         self,
-        relation_type,
-        event_type,
-        key=None,
+        relation_type: str,
+        event_type: str,
+        key: Optional[str] = None,
         content: Optional[dict] = None,
-        access_token=None,
-        parent_id=None,
-    ):
+        access_token: Optional[str] = None,
+        parent_id: Optional[str] = None,
+    ) -> FakeChannel:
         """Helper function to send a relation pointing at `self.parent_id`
 
         Args:
-            relation_type (str): One of `RelationTypes`
-            event_type (str): The type of the event to create
-            parent_id (str): The event_id this relation relates to. If None, then self.parent_id
-            key (str|None): The aggregation key used for m.annotation relation
-                type.
-            content(dict|None): The content of the created event.
-            access_token (str|None): The access token used to send the relation,
-                defaults to `self.user_token`
+            relation_type: One of `RelationTypes`
+            event_type: The type of the event to create
+            key: The aggregation key used for m.annotation relation type.
+            content: The content of the created event.
+            access_token: The access token used to send the relation, defaults
+                to `self.user_token`
+            parent_id: The event_id this relation relates to. If None, then self.parent_id
 
         Returns:
             FakeChannel
@@ -712,12 +711,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
             % (self.room, original_id, relation_type, event_type, query),
-            json.dumps(content or {}).encode("utf-8"),
+            content or {},
             access_token=access_token,
         )
         return channel
 
-    def _create_user(self, localpart):
+    def _create_user(self, localpart: str) -> Tuple[str, str]:
         user_id = self.register_user(localpart, "abc123")
         access_token = self.login(localpart, "abc123")
 
diff --git a/tests/server.py b/tests/server.py
index 64645651ce..103351b487 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,3 +1,17 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import logging
 from collections import deque
@@ -27,9 +41,10 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
-from twisted.web.server import Site
+from twisted.web.server import Request, Site
 
 from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests.utils import setup_test_homeserver as _sth
@@ -198,14 +213,14 @@ class FakeSite:
 def make_request(
     reactor,
     site: Union[Site, FakeSite],
-    method,
-    path,
-    content=b"",
-    access_token=None,
-    request=SynapseRequest,
-    shorthand=True,
-    federation_auth_origin=None,
-    content_is_form=False,
+    method: Union[bytes, str],
+    path: Union[bytes, str],
+    content: Union[bytes, str, JsonDict] = b"",
+    access_token: Optional[str] = None,
+    request: Request = SynapseRequest,
+    shorthand: bool = True,
+    federation_auth_origin: Optional[bytes] = None,
+    content_is_form: bool = False,
     await_result: bool = True,
     custom_headers: Optional[
         Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
@@ -218,26 +233,23 @@ def make_request(
     Returns the fake Channel object which records the response to the request.
 
     Args:
+        reactor:
         site: The twisted Site to use to render the request
-
-        method (bytes/unicode): The HTTP request method ("verb").
-        path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
-        escaped UTF-8 & spaces and such).
-        content (bytes or dict): The body of the request. JSON-encoded, if
-        a dict.
+        method: The HTTP request method ("verb").
+        path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
+        content: The body of the request. JSON-encoded, if a str of bytes.
+        access_token: The access token to add as authorization for the request.
+        request: The request class to create.
         shorthand: Whether to try and be helpful and prefix the given URL
-        with the usual REST API path, if it doesn't contain it.
-        federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+            with the usual REST API path, if it doesn't contain it.
+        federation_auth_origin: if set to not-None, we will add a fake
             Authorization header pretenting to be the given server name.
         content_is_form: Whether the content is URL encoded form data. Adds the
             'Content-Type': 'application/x-www-form-urlencoded' header.
-
-        custom_headers: (name, value) pairs to add as request headers
-
         await_result: whether to wait for the request to complete rendering. If true,
              will pump the reactor until the the renderer tells the channel the request
              is finished.
-
+        custom_headers: (name, value) pairs to add as request headers
         client_ip: The IP to use as the requesting IP. Useful for testing
             ratelimiting.
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 81c1a9e9d2..a9b60b7eeb 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -46,7 +46,7 @@ from synapse.logging.context import (
     set_current_context,
 )
 from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -401,7 +401,7 @@ class HomeserverTestCase(TestCase):
         self,
         method: Union[bytes, str],
         path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
+        content: Union[bytes, str, JsonDict] = b"",
         access_token: Optional[str] = None,
         request: Type[T] = SynapseRequest,
         shorthand: bool = True,