summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-07-04 18:08:56 +0100
committerGitHub <noreply@github.com>2022-07-04 18:08:56 +0100
commitd102ad67fddc650c34baa89dc7b2926d46a9aeca (patch)
tree542df6d58640540ea934cd7a5bd0e1486fe96bdf
parentRevert "Up the dependency on canonicaljson to ^1.5.0" (diff)
downloadsynapse-d102ad67fddc650c34baa89dc7b2926d46a9aeca.tar.xz
annotate tests.server.FakeChannel (#13136)
-rw-r--r--changelog.d/13136.misc1
-rw-r--r--tests/rest/admin/test_room.py4
-rw-r--r--tests/rest/admin/test_user.py2
-rw-r--r--tests/rest/client/test_account.py5
-rw-r--r--tests/rest/client/test_profile.py10
-rw-r--r--tests/rest/client/test_relations.py2
-rw-r--r--tests/server.py38
7 files changed, 36 insertions, 26 deletions
diff --git a/changelog.d/13136.misc b/changelog.d/13136.misc
new file mode 100644
index 0000000000..6cf451d8cf
--- /dev/null
+++ b/changelog.d/13136.misc
@@ -0,0 +1 @@
+Add type annotations to `tests.server`.
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ca6af9417b..230dc76f72 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1579,8 +1579,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
-        self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
-        self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
+        self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
+        self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
 
     def test_single_room(self) -> None:
         """Test that a single room can be requested correctly"""
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0d44102237..e32aaadb98 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1488,7 +1488,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         if channel.code != HTTPStatus.OK:
             raise HttpResponseException(
-                channel.code, channel.result["reason"], channel.json_body
+                channel.code, channel.result["reason"], channel.result["body"]
             )
 
         # Set monthly active users to the limit
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a43a137273..1f9b65351e 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -949,7 +949,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         client_secret: str,
         next_link: Optional[str] = None,
         expect_code: int = 200,
-    ) -> str:
+    ) -> Optional[str]:
         """Request a validation token to add an email address to a user's account
 
         Args:
@@ -959,7 +959,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             expect_code: Expected return code of the call
 
         Returns:
-            The ID of the new threepid validation session
+            The ID of the new threepid validation session, or None if the response
+            did not contain a session ID.
         """
         body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
         if next_link:
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 29bed0e872..8de5a342ae 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -153,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)
 
-    def _get_displayname(self, name: Optional[str] = None) -> str:
+    def _get_displayname(self, name: Optional[str] = None) -> Optional[str]:
         channel = self.make_request(
             "GET", "/profile/%s/displayname" % (name or self.owner,)
         )
         self.assertEqual(channel.code, 200, channel.result)
-        return channel.json_body["displayname"]
+        # FIXME: If a user has no displayname set, Synapse returns 200 and omits a
+        # displayname from the response. This contradicts the spec, see #13137.
+        return channel.json_body.get("displayname")
 
-    def _get_avatar_url(self, name: Optional[str] = None) -> str:
+    def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]:
         channel = self.make_request(
             "GET", "/profile/%s/avatar_url" % (name or self.owner,)
         )
         self.assertEqual(channel.code, 200, channel.result)
+        # FIXME: If a user has no avatar set, Synapse returns 200 and omits an
+        # avatar_url from the response. This contradicts the spec, see #13137.
         return channel.json_body.get("avatar_url")
 
     @unittest.override_config({"max_avatar_size": 50})
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index aa84906548..ad03eee17b 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -800,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
             )
             expected_event_ids.append(channel.json_body["event_id"])
 
-        prev_token = ""
+        prev_token: Optional[str] = ""
         found_event_ids: List[str] = []
         for _ in range(20):
             from_token = ""
diff --git a/tests/server.py b/tests/server.py
index ce017ca0f6..df3f1564c9 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -43,6 +43,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import (
     IAddress,
+    IConsumer,
     IHostnameResolver,
     IProtocol,
     IPullProducer,
@@ -53,11 +54,7 @@ from twisted.internet.interfaces import (
     ITransport,
 )
 from twisted.python.failure import Failure
-from twisted.test.proto_helpers import (
-    AccumulatingProtocol,
-    MemoryReactor,
-    MemoryReactorClock,
-)
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
@@ -96,6 +93,7 @@ class TimedOutException(Exception):
     """
 
 
+@implementer(IConsumer)
 @attr.s(auto_attribs=True)
 class FakeChannel:
     """
@@ -104,7 +102,7 @@ class FakeChannel:
     """
 
     site: Union[Site, "FakeSite"]
-    _reactor: MemoryReactor
+    _reactor: MemoryReactorClock
     result: dict = attr.Factory(dict)
     _ip: str = "127.0.0.1"
     _producer: Optional[Union[IPullProducer, IPushProducer]] = None
@@ -122,7 +120,7 @@ class FakeChannel:
         self._request = request
 
     @property
-    def json_body(self):
+    def json_body(self) -> JsonDict:
         return json.loads(self.text_body)
 
     @property
@@ -140,7 +138,7 @@ class FakeChannel:
         return self.result.get("done", False)
 
     @property
-    def code(self):
+    def code(self) -> int:
         if not self.result:
             raise Exception("No result yet.")
         return int(self.result["code"])
@@ -160,7 +158,7 @@ class FakeChannel:
         self.result["reason"] = reason
         self.result["headers"] = headers
 
-    def write(self, content):
+    def write(self, content: bytes) -> None:
         assert isinstance(content, bytes), "Should be bytes! " + repr(content)
 
         if "body" not in self.result:
@@ -168,11 +166,16 @@ class FakeChannel:
 
         self.result["body"] += content
 
-    def registerProducer(self, producer, streaming):
+    # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
+    def registerProducer(  # type: ignore[override]
+        self,
+        producer: Union[IPullProducer, IPushProducer],
+        streaming: bool,
+    ) -> None:
         self._producer = producer
         self.producerStreaming = streaming
 
-        def _produce():
+        def _produce() -> None:
             if self._producer:
                 self._producer.resumeProducing()
                 self._reactor.callLater(0.1, _produce)
@@ -180,31 +183,32 @@ class FakeChannel:
         if not streaming:
             self._reactor.callLater(0.0, _produce)
 
-    def unregisterProducer(self):
+    def unregisterProducer(self) -> None:
         if self._producer is None:
             return
 
         self._producer = None
 
-    def requestDone(self, _self):
+    def requestDone(self, _self: Request) -> None:
         self.result["done"] = True
         if isinstance(_self, SynapseRequest):
+            assert _self.logcontext is not None
             self.resource_usage = _self.logcontext.get_resource_usage()
 
-    def getPeer(self):
+    def getPeer(self) -> IAddress:
         # We give an address so that getClientAddress/getClientIP returns a non null entry,
         # causing us to record the MAU
         return address.IPv4Address("TCP", self._ip, 3423)
 
-    def getHost(self):
+    def getHost(self) -> IAddress:
         # this is called by Request.__init__ to configure Request.host.
         return address.IPv4Address("TCP", "127.0.0.1", 8888)
 
-    def isSecure(self):
+    def isSecure(self) -> bool:
         return False
 
     @property
-    def transport(self):
+    def transport(self) -> "FakeChannel":
         return self
 
     def await_result(self, timeout_ms: int = 1000) -> None: