summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_events.py20
-rw-r--r--tests/rest/client/test_groups.py2
-rw-r--r--tests/rest/client/test_register.py110
3 files changed, 71 insertions, 61 deletions
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 145f247836..1b1392fa2f 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -16,8 +16,12 @@
 
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.rest.client import events, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
         config["enable_registration_captcha"] = False
@@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(config=config)
 
-        hs.get_federation_handler = Mock()
+        hs.get_federation_handler = Mock()  # type: ignore[assignment]
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         # register an account
         self.user_id = self.register_user("sid1", "pass")
@@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         self.other_user = self.register_user("other2", "pass")
         self.other_token = self.login(self.other_user, "pass")
 
-    def test_stream_basic_permissions(self):
+    def test_stream_basic_permissions(self) -> None:
         # invalid token, expect 401
         # note: this is in violation of the original v1 spec, which expected
         # 403. However, since the v1 spec no longer exists and the v1
@@ -76,7 +80,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         self.assertTrue("start" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
-    def test_stream_room_permissions(self):
+    def test_stream_room_permissions(self) -> None:
         room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
         self.helper.send(room_id, tok=self.other_token)
 
@@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
 
         # left to room (expect no content for room)
 
-    def TODO_test_stream_items(self):
+    def TODO_test_stream_items(self) -> None:
         # new user, no content
 
         # join room, expect 1 item (join)
@@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, hs, reactor, clock):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         # register an account
         self.user_id = self.register_user("sid1", "pass")
@@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
 
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
-    def test_get_event_via_events(self):
+    def test_get_event_via_events(self) -> None:
         resp = self.helper.send(self.room_id, tok=self.token)
         event_id = resp["event_id"]
 
diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py
index c99f54cf4f..e067cf825c 100644
--- a/tests/rest/client/test_groups.py
+++ b/tests/rest/client/test_groups.py
@@ -25,7 +25,7 @@ class GroupsTestCase(unittest.HomeserverTestCase):
     servlets = [room.register_servlets, groups.register_servlets]
 
     @override_config({"enable_group_creation": True})
-    def test_rooms_limited_by_visibility(self):
+    def test_rooms_limited_by_visibility(self) -> None:
         group_id = "+spqr:test"
 
         # Alice creates a group
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 4b95b8541c..9aebf1735a 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -16,15 +16,21 @@
 import datetime
 import json
 import os
+from typing import Any, Dict, List, Tuple
 
 import pkg_resources
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.server import HomeServer
 from synapse.storage._base import db_to_json
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
@@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     ]
     url = b"/_matrix/client/r0/register"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["allow_guest_access"] = True
         return config
 
-    def test_POST_appservice_registration_valid(self):
+    def test_POST_appservice_registration_valid(self) -> None:
         user_id = "@as_user_kermit:test"
         as_token = "i_am_an_app_service"
 
@@ -69,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         det_data = {"user_id": user_id, "home_server": self.hs.hostname}
         self.assertDictContainsSubset(det_data, channel.json_body)
 
-    def test_POST_appservice_registration_no_type(self):
+    def test_POST_appservice_registration_no_type(self) -> None:
         as_token = "i_am_an_app_service"
 
         appservice = ApplicationService(
@@ -89,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.result["code"], b"400", channel.result)
 
-    def test_POST_appservice_registration_invalid(self):
+    def test_POST_appservice_registration_invalid(self) -> None:
         self.appservice = None  # no application service exists
         request_data = json.dumps(
             {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
@@ -100,21 +106,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.result["code"], b"401", channel.result)
 
-    def test_POST_bad_password(self):
+    def test_POST_bad_password(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": 666})
         channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEqual(channel.result["code"], b"400", channel.result)
         self.assertEqual(channel.json_body["error"], "Invalid password")
 
-    def test_POST_bad_username(self):
+    def test_POST_bad_username(self) -> None:
         request_data = json.dumps({"username": 777, "password": "monkey"})
         channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEqual(channel.result["code"], b"400", channel.result)
         self.assertEqual(channel.json_body["error"], "Invalid username")
 
-    def test_POST_user_valid(self):
+    def test_POST_user_valid(self) -> None:
         user_id = "@kermit:test"
         device_id = "frogfone"
         params = {
@@ -135,7 +141,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertDictContainsSubset(det_data, channel.json_body)
 
     @override_config({"enable_registration": False})
-    def test_POST_disabled_registration(self):
+    def test_POST_disabled_registration(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
 
@@ -145,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["error"], "Registration has been disabled")
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
 
-    def test_POST_guest_registration(self):
+    def test_POST_guest_registration(self) -> None:
         self.hs.config.key.macaroon_secret_key = "test"
         self.hs.config.registration.allow_guest_access = True
 
@@ -155,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
-    def test_POST_disabled_guest_registration(self):
+    def test_POST_disabled_guest_registration(self) -> None:
         self.hs.config.registration.allow_guest_access = False
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -164,7 +170,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["error"], "Guest access is disabled")
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
-    def test_POST_ratelimiting_guest(self):
+    def test_POST_ratelimiting_guest(self) -> None:
         for i in range(0, 6):
             url = self.url + b"?kind=guest"
             channel = self.make_request(b"POST", url, b"{}")
@@ -182,7 +188,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200", channel.result)
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
-    def test_POST_ratelimiting(self):
+    def test_POST_ratelimiting(self) -> None:
         for i in range(0, 6):
             params = {
                 "username": "kermit" + str(i),
@@ -206,7 +212,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200", channel.result)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_requires_token(self):
+    def test_POST_registration_requires_token(self) -> None:
         username = "kermit"
         device_id = "frogfone"
         token = "abcd"
@@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
-        params = {
+        params: JsonDict = {
             "username": username,
             "password": "monkey",
             "device_id": device_id,
@@ -280,8 +286,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(res["pending"], 0)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_invalid(self):
-        params = {
+    def test_POST_registration_token_invalid(self) -> None:
+        params: JsonDict = {
             "username": "kermit",
             "password": "monkey",
         }
@@ -314,7 +320,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["completed"], [])
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_limit_uses(self):
+    def test_POST_registration_token_limit_uses(self) -> None:
         token = "abcd"
         store = self.hs.get_datastores().main
         # Create token that can be used once
@@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
-        params1 = {"username": "bert", "password": "monkey"}
-        params2 = {"username": "ernie", "password": "monkey"}
+        params1: JsonDict = {"username": "bert", "password": "monkey"}
+        params2: JsonDict = {"username": "ernie", "password": "monkey"}
         # Do 2 requests without auth to get two session IDs
         channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
         session1 = channel1.json_body["session"]
@@ -388,7 +394,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["completed"], [])
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_expiry(self):
+    def test_POST_registration_token_expiry(self) -> None:
         token = "abcd"
         now = self.hs.get_clock().time_msec()
         store = self.hs.get_datastores().main
@@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
-        params = {"username": "kermit", "password": "monkey"}
+        params: JsonDict = {"username": "kermit", "password": "monkey"}
         # Request without auth to get session
         channel = self.make_request(b"POST", self.url, json.dumps(params))
         session = channel.json_body["session"]
@@ -436,7 +442,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_session_expiry(self):
+    def test_POST_registration_token_session_expiry(self) -> None:
         """Test `pending` is decremented when an uncompleted session expires."""
         token = "abcd"
         store = self.hs.get_datastores().main
@@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         )
 
         # Do 2 requests without auth to get two session IDs
-        params1 = {"username": "bert", "password": "monkey"}
-        params2 = {"username": "ernie", "password": "monkey"}
+        params1: JsonDict = {"username": "bert", "password": "monkey"}
+        params2: JsonDict = {"username": "ernie", "password": "monkey"}
         channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
         session1 = channel1.json_body["session"]
         channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
@@ -522,7 +528,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(pending, 0)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_session_expiry_deleted_token(self):
+    def test_POST_registration_token_session_expiry_deleted_token(self) -> None:
         """Test session expiry doesn't break when the token is deleted.
 
         1. Start but don't complete UIA with a registration token
@@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         )
 
         # Do request without auth to get a session ID
-        params = {"username": "kermit", "password": "monkey"}
+        params: JsonDict = {"username": "kermit", "password": "monkey"}
         channel = self.make_request(b"POST", self.url, json.dumps(params))
         session = channel.json_body["session"]
 
@@ -570,7 +576,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
         )
 
-    def test_advertised_flows(self):
+    def test_advertised_flows(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEqual(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
@@ -593,7 +599,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_advertised_flows_captcha_and_terms_and_3pids(self):
+    def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEqual(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
@@ -625,7 +631,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_advertised_flows_no_msisdn_email_required(self):
+    def test_advertised_flows_no_msisdn_email_required(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEqual(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
@@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_request_token_existing_email_inhibit_error(self):
+    def test_request_token_existing_email_inhibit_error(self) -> None:
         """Test that requesting a token via this endpoint doesn't leak existing
         associations if configured that way.
         """
@@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_reject_invalid_email(self):
+    def test_reject_invalid_email(self) -> None:
         """Check that bad emails are rejected"""
 
         # Test for email with multiple @
@@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "inhibit_user_in_use_error": True,
         }
     )
-    def test_inhibit_user_in_use_error(self):
+    def test_inhibit_user_in_use_error(self) -> None:
         """Tests that the 'inhibit_user_in_use_error' configuration flag behaves
         correctly.
         """
@@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         account_validity.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         # Test for account expiring after a week.
         config["enable_registration"] = True
@@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def test_validity_period(self):
+    def test_validity_period(self) -> None:
         self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -810,7 +816,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
         )
 
-    def test_manual_renewal(self):
+    def test_manual_renewal(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -833,7 +839,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEqual(channel.result["code"], b"200", channel.result)
 
-    def test_manual_expire(self):
+    def test_manual_expire(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -858,7 +864,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
         )
 
-    def test_logging_out_expired_user(self):
+    def test_logging_out_expired_user(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         account.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
 
         # Test for account expiring after a week and renewal emails being sent 2
@@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
         self.hs = self.setup_test_homeserver(config=config)
 
-        async def sendmail(*args, **kwargs):
+        async def sendmail(*args: Any, **kwargs: Any) -> None:
             self.email_attempts.append((args, kwargs))
 
-        self.email_attempts = []
+        self.email_attempts: List[Tuple[Any, Any]] = []
         self.hs.get_send_email_handler()._sendmail = sendmail
 
         self.store = self.hs.get_datastores().main
 
         return self.hs
 
-    def test_renewal_email(self):
+    def test_renewal_email(self) -> None:
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
@@ -999,7 +1005,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEqual(channel.result["code"], b"200", channel.result)
 
-    def test_renewal_invalid_token(self):
+    def test_renewal_invalid_token(self) -> None:
         # Hit the renewal endpoint with an invalid token and check that it behaves as
         # expected, i.e. that it responds with 404 Not Found and the correct HTML.
         url = "/_matrix/client/unstable/account_validity/renew?token=123"
@@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             channel.result["body"], expected_html.encode("utf8"), channel.result
         )
 
-    def test_manual_email_send(self):
+    def test_manual_email_send(self) -> None:
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
@@ -1032,7 +1038,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(len(self.email_attempts), 1)
 
-    def test_deactivated_user(self):
+    def test_deactivated_user(self) -> None:
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
@@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(len(self.email_attempts), 0)
 
-    def create_user(self):
+    def create_user(self) -> Tuple[str, str]:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
         # We need to manually add an email address otherwise the handler will do
@@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         )
         return user_id, tok
 
-    def test_manual_email_send_expired_account(self):
+    def test_manual_email_send_expired_account(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
     servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.validity_period = 10
         self.max_delta = self.validity_period * 10.0 / 100.0
 
@@ -1135,7 +1141,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def test_background_job(self):
+    def test_background_job(self) -> None:
         """
         Tests the same thing as test_background_job, except that it sets the
         startup_job_max_delta parameter and checks that the expiration date is within the
@@ -1158,12 +1164,12 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
     servlets = [register.register_servlets]
     url = "/_matrix/client/v1/register/m.login.registration_token/validity"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["registration_requires_token"] = True
         return config
 
-    def test_GET_token_valid(self):
+    def test_GET_token_valid(self) -> None:
         token = "abcd"
         store = self.hs.get_datastores().main
         self.get_success(
@@ -1186,7 +1192,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["valid"], True)
 
-    def test_GET_token_invalid(self):
+    def test_GET_token_invalid(self) -> None:
         token = "1234"
         channel = self.make_request(
             b"GET",
@@ -1198,7 +1204,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
     @override_config(
         {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
     )
-    def test_GET_ratelimiting(self):
+    def test_GET_ratelimiting(self) -> None:
         token = "1234"
 
         for i in range(0, 6):