summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_device.py2
-rw-r--r--tests/rest/admin/test_event_reports.py382
-rw-r--r--tests/rest/admin/test_room.py2
-rw-r--r--tests/rest/admin/test_user.py116
-rw-r--r--tests/rest/client/v1/test_login.py134
-rw-r--r--tests/rest/client/v1/test_push_rule_attrs.py448
-rw-r--r--tests/rest/client/v1/test_rooms.py22
-rw-r--r--tests/rest/client/v2_alpha/test_account.py138
-rw-r--r--tests/rest/media/v1/test_media_storage.py39
-rw-r--r--tests/rest/test_well_known.py2
10 files changed, 1255 insertions, 30 deletions
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index faa7f381a9..92c9058887 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.render(request)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
 
         # Ensure the display name was not updated.
         request, channel = self.make_request(
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
new file mode 100644
index 0000000000..bf79086f78
--- /dev/null
+++ b/tests/rest/admin/test_event_reports.py
@@ -0,0 +1,382 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import report_event
+
+from tests import unittest
+
+
+class EventReportsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        report_event.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.room_id1 = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok, is_public=True
+        )
+        self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+        self.room_id2 = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok, is_public=True
+        )
+        self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
+
+        # Two rooms and two users. Every user sends and reports every room event
+        for i in range(5):
+            self._create_event_and_report(
+                room_id=self.room_id1, user_tok=self.other_user_tok,
+            )
+        for i in range(5):
+            self._create_event_and_report(
+                room_id=self.room_id2, user_tok=self.other_user_tok,
+            )
+        for i in range(5):
+            self._create_event_and_report(
+                room_id=self.room_id1, user_tok=self.admin_user_tok,
+            )
+        for i in range(5):
+            self._create_event_and_report(
+                room_id=self.room_id2, user_tok=self.admin_user_tok,
+            )
+
+        self.url = "/_synapse/admin/v1/event_reports"
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url, access_token=self.other_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_default_success(self):
+        """
+        Testing list of reported events
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 20)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["event_reports"])
+
+    def test_limit(self):
+        """
+        Testing list of reported events with limit
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 5)
+        self.assertEqual(channel.json_body["next_token"], 5)
+        self._check_fields(channel.json_body["event_reports"])
+
+    def test_from(self):
+        """
+        Testing list of reported events with a defined starting point (from)
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 15)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["event_reports"])
+
+    def test_limit_and_from(self):
+        """
+        Testing list of reported events with a defined starting point and limit
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(channel.json_body["next_token"], 15)
+        self.assertEqual(len(channel.json_body["event_reports"]), 10)
+        self._check_fields(channel.json_body["event_reports"])
+
+    def test_filter_room(self):
+        """
+        Testing list of reported events with a filter of room
+        """
+
+        request, channel = self.make_request(
+            "GET",
+            self.url + "?room_id=%s" % self.room_id1,
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 10)
+        self.assertEqual(len(channel.json_body["event_reports"]), 10)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["event_reports"])
+
+        for report in channel.json_body["event_reports"]:
+            self.assertEqual(report["room_id"], self.room_id1)
+
+    def test_filter_user(self):
+        """
+        Testing list of reported events with a filter of user
+        """
+
+        request, channel = self.make_request(
+            "GET",
+            self.url + "?user_id=%s" % self.other_user,
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 10)
+        self.assertEqual(len(channel.json_body["event_reports"]), 10)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["event_reports"])
+
+        for report in channel.json_body["event_reports"]:
+            self.assertEqual(report["user_id"], self.other_user)
+
+    def test_filter_user_and_room(self):
+        """
+        Testing list of reported events with a filter of user and room
+        """
+
+        request, channel = self.make_request(
+            "GET",
+            self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 5)
+        self.assertEqual(len(channel.json_body["event_reports"]), 5)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["event_reports"])
+
+        for report in channel.json_body["event_reports"]:
+            self.assertEqual(report["user_id"], self.other_user)
+            self.assertEqual(report["room_id"], self.room_id1)
+
+    def test_valid_search_order(self):
+        """
+        Testing search order. Order by timestamps.
+        """
+
+        # fetch the most recent first, largest timestamp
+        request, channel = self.make_request(
+            "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 20)
+        report = 1
+        while report < len(channel.json_body["event_reports"]):
+            self.assertGreaterEqual(
+                channel.json_body["event_reports"][report - 1]["received_ts"],
+                channel.json_body["event_reports"][report]["received_ts"],
+            )
+            report += 1
+
+        # fetch the oldest first, smallest timestamp
+        request, channel = self.make_request(
+            "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 20)
+        report = 1
+        while report < len(channel.json_body["event_reports"]):
+            self.assertLessEqual(
+                channel.json_body["event_reports"][report - 1]["received_ts"],
+                channel.json_body["event_reports"][report]["received_ts"],
+            )
+            report += 1
+
+    def test_invalid_search_order(self):
+        """
+        Testing that a invalid search order returns a 400
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+        self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+
+    def test_limit_is_negative(self):
+        """
+        Testing that a negative list parameter returns a 400
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+    def test_from_is_negative(self):
+        """
+        Testing that a negative from parameter returns a 400
+        """
+
+        request, channel = self.make_request(
+            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+    def test_next_token(self):
+        """
+        Testing that `next_token` appears at the right place
+        """
+
+        #  `next_token` does not appear
+        # Number of results is the number of entries
+        request, channel = self.make_request(
+            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 20)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        request, channel = self.make_request(
+            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 20)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        request, channel = self.make_request(
+            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 19)
+        self.assertEqual(channel.json_body["next_token"], 19)
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        #  `next_token` does not appear
+        request, channel = self.make_request(
+            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], 20)
+        self.assertEqual(len(channel.json_body["event_reports"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def _create_event_and_report(self, room_id, user_tok):
+        """Create and report events
+        """
+        resp = self.helper.send(room_id, tok=user_tok)
+        event_id = resp["event_id"]
+
+        request, channel = self.make_request(
+            "POST",
+            "rooms/%s/report/%s" % (room_id, event_id),
+            json.dumps({"score": -100, "reason": "this makes me sad"}),
+            access_token=user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+    def _check_fields(self, content):
+        """Checks that all attributes are present in a event report
+        """
+        for c in content:
+            self.assertIn("id", c)
+            self.assertIn("received_ts", c)
+            self.assertIn("room_id", c)
+            self.assertIn("event_id", c)
+            self.assertIn("user_id", c)
+            self.assertIn("reason", c)
+            self.assertIn("content", c)
+            self.assertIn("sender", c)
+            self.assertIn("room_alias", c)
+            self.assertIn("event_json", c)
+            self.assertIn("score", c["content"])
+            self.assertIn("reason", c["content"])
+            self.assertIn("auth_events", c["event_json"])
+            self.assertIn("type", c["event_json"])
+            self.assertIn("room_id", c["event_json"])
+            self.assertIn("sender", c["event_json"])
+            self.assertIn("content", c["event_json"])
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 408c568a27..6dfc709dc5 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1174,6 +1174,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         self.assertIn("room_id", channel.json_body)
         self.assertIn("name", channel.json_body)
+        self.assertIn("topic", channel.json_body)
+        self.assertIn("avatar", channel.json_body)
         self.assertIn("canonical_alias", channel.json_body)
         self.assertIn("joined_members", channel.json_body)
         self.assertIn("joined_local_members", channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 160c630235..98d0623734 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -22,8 +22,8 @@ from mock import Mock
 
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
-from synapse.api.errors import HttpResponseException, ResourceLimitError
-from synapse.rest.client.v1 import login
+from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import sync
 
 from tests import unittest
@@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         )
         self.render(request)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self._is_erased("@user:test", False)
+        d = self.store.mark_user_erased("@user:test")
+        self.assertIsNone(self.get_success(d))
+        self._is_erased("@user:test", True)
 
         # Attempt to reactivate the user (without a password).
         request, channel = self.make_request(
@@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self._is_erased("@user:test", False)
 
     def test_set_user_as_admin(self):
         """
@@ -995,3 +1000,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Ensure they're still alive
         self.assertEqual(0, channel.json_body["deactivated"])
+
+    def _is_erased(self, user_id, expect):
+        """Assert that the user is erased or not
+        """
+        d = self.store.is_user_erased(user_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertFalse(self.get_success(d))
+
+
+class UserMembershipRestTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.url = "/_synapse/admin/v1/users/%s/joined_rooms" % urllib.parse.quote(
+            self.other_user
+        )
+
+    def test_no_auth(self):
+        """
+        Try to list rooms of an user without authentication.
+        """
+        request, channel = self.make_request("GET", self.url, b"{}")
+        self.render(request)
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user", "pass")
+
+        request, channel = self.make_request(
+            "GET", self.url, access_token=other_user_token,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_does_not_exist(self):
+        """
+        Tests that a lookup for a user that does not exist returns a 404
+        """
+        url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
+        request, channel = self.make_request(
+            "GET", url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that a lookup for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
+
+        request, channel = self.make_request(
+            "GET", url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+    def test_get_rooms(self):
+        """
+        Tests that a normal lookup for rooms is successfully
+        """
+        # Create rooms and join
+        other_user_tok = self.login("user", "pass")
+        number_rooms = 5
+        for n in range(number_rooms):
+            self.helper.create_room_as(self.other_user, tok=other_user_tok)
+
+        # Get rooms
+        request, channel = self.make_request(
+            "GET", self.url, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(number_rooms, channel.json_body["total"])
+        self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2668662c9e..5d987a30c7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -7,8 +7,9 @@ from mock import Mock
 import jwt
 
 import synapse.rest.admin
+from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
-from synapse.rest.client.v2_alpha import devices
+from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
 
 from tests import unittest
@@ -748,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
             "JWT validation failed: Signature verification failed",
         )
+
+
+AS_USER = "as_user_alice"
+
+
+class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        register.register_servlets,
+    ]
+
+    def register_as_user(self, username):
+        request, channel = self.make_request(
+            b"POST",
+            "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
+            {"username": username},
+        )
+        self.render(request)
+
+    def make_homeserver(self, reactor, clock):
+        self.hs = self.setup_test_homeserver()
+
+        self.service = ApplicationService(
+            id="unique_identifier",
+            token="some_token",
+            hostname="example.com",
+            sender="@asbot:example.com",
+            namespaces={
+                ApplicationService.NS_USERS: [
+                    {"regex": r"@as_user.*", "exclusive": False}
+                ],
+                ApplicationService.NS_ROOMS: [],
+                ApplicationService.NS_ALIASES: [],
+            },
+        )
+        self.another_service = ApplicationService(
+            id="another__identifier",
+            token="another_token",
+            hostname="example.com",
+            sender="@as2bot:example.com",
+            namespaces={
+                ApplicationService.NS_USERS: [
+                    {"regex": r"@as2_user.*", "exclusive": False}
+                ],
+                ApplicationService.NS_ROOMS: [],
+                ApplicationService.NS_ALIASES: [],
+            },
+        )
+
+        self.hs.get_datastore().services_cache.append(self.service)
+        self.hs.get_datastore().services_cache.append(self.another_service)
+        return self.hs
+
+    def test_login_appservice_user(self):
+        """Test that an appservice user can use /login
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": AS_USER},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_login_appservice_user_bot(self):
+        """Test that the appservice bot can use /login
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": self.service.sender},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_login_appservice_wrong_user(self):
+        """Test that non-as users cannot login with the as token
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+
+    def test_login_appservice_wrong_as(self):
+        """Test that as users cannot login with wrong as token
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": AS_USER},
+        }
+        request, channel = self.make_request(
+            b"POST", LOGIN_URL, params, access_token=self.another_service.token
+        )
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+
+    def test_login_appservice_no_token(self):
+        """Test that users must provide a token when using the appservice
+           login method
+        """
+        self.register_as_user(AS_USER)
+
+        params = {
+            "type": login.LoginRestServlet.APPSERVICE_TYPE,
+            "identifier": {"type": "m.id.user", "user": AS_USER},
+        }
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
new file mode 100644
index 0000000000..081052f6a6
--- /dev/null
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, push_rule, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PushRuleAttributesTestCase(HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+        push_rule.register_servlets,
+    ]
+    hijack_auth = False
+
+    def test_enabled_on_creation(self):
+        """
+        Tests the GET and PUT of push rules' `enabled` endpoints.
+        Tests that a rule is enabled upon creation, even though a rule with that
+            ruleId existed previously and was disabled.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET enabled for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], True)
+
+    def test_enabled_on_recreation(self):
+        """
+        Tests the GET and PUT of push rules' `enabled` endpoints.
+        Tests that a rule is enabled upon creation, even if a rule with that
+            ruleId existed previously and was disabled.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # disable the rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": False},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check rule disabled
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], False)
+
+        # DELETE the rule
+        request, channel = self.make_request(
+            "DELETE", "/pushrules/global/override/best.friend", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET enabled for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], True)
+
+    def test_enabled_disable(self):
+        """
+        Tests the GET and PUT of push rules' `enabled` endpoints.
+        Tests that a rule is disabled and enabled when we ask for it.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # disable the rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": False},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check rule disabled
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], False)
+
+        # re-enable the rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": True},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check rule enabled
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], True)
+
+    def test_enabled_404_when_get_non_existent(self):
+        """
+        Tests that `enabled` gives 404 when the rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET enabled for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # DELETE the rule
+        request, channel = self.make_request(
+            "DELETE", "/pushrules/global/override/best.friend", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check 404 for deleted rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_enabled_404_when_get_non_existent_server_rule(self):
+        """
+        Tests that `enabled` gives 404 when the server-default rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_enabled_404_when_put_non_existent_rule(self):
+        """
+        Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": True},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_enabled_404_when_put_non_existent_server_rule(self):
+        """
+        Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/.m.muahahah/enabled",
+            {"enabled": True},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_get(self):
+        """
+        Tests that `actions` gives you what you expect on a fresh rule.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET actions for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
+        )
+
+    def test_actions_put(self):
+        """
+        Tests that PUT on actions updates the value you'd get from GET.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # change the rule actions
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/actions",
+            {"actions": ["dont_notify"]},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET actions for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["actions"], ["dont_notify"])
+
+    def test_actions_404_when_get_non_existent(self):
+        """
+        Tests that `actions` gives 404 when the rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # DELETE the rule
+        request, channel = self.make_request(
+            "DELETE", "/pushrules/global/override/best.friend", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check 404 for deleted rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_404_when_get_non_existent_server_rule(self):
+        """
+        Tests that `actions` gives 404 when the server-default rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_404_when_put_non_existent_rule(self):
+        """
+        Tests that `actions` gives 404 when putting to a rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/actions",
+            {"actions": ["dont_notify"]},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_404_when_put_non_existent_server_rule(self):
+        """
+        Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/.m.muahahah/actions",
+            {"actions": ["dont_notify"]},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0a567b032f..0d809d25d5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -905,6 +905,7 @@ class RoomMessageListTestCase(RoomBase):
         first_token = self.get_success(
             store.get_topological_token_for_event(first_event_id)
         )
+        first_token_str = self.get_success(first_token.to_string(store))
 
         # Send a second message in the room, which won't be removed, and which we'll
         # use as the marker to purge events before.
@@ -912,6 +913,7 @@ class RoomMessageListTestCase(RoomBase):
         second_token = self.get_success(
             store.get_topological_token_for_event(second_event_id)
         )
+        second_token_str = self.get_success(second_token.to_string(store))
 
         # Send a third event in the room to ensure we don't fall under any edge case
         # due to our marker being the latest forward extremity in the room.
@@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
         request, channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
-            % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+            % (
+                self.room_id,
+                second_token_str,
+                json.dumps({"types": [EventTypes.Message]}),
+            ),
         )
         self.render(request)
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
             pagination_handler._purge_history(
                 purge_id=purge_id,
                 room_id=self.room_id,
-                token=second_token,
+                token=second_token_str,
                 delete_local_events=True,
             )
         )
@@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
         request, channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
-            % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+            % (
+                self.room_id,
+                second_token_str,
+                json.dumps({"types": [EventTypes.Message]}),
+            ),
         )
         self.render(request)
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
         request, channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
-            % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+            % (
+                self.room_id,
+                first_token_str,
+                json.dumps({"types": [EventTypes.Message]}),
+            ),
         )
         self.render(request)
         self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..ae2cd67f35 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
 import os
 import re
 from email.parser import Parser
+from typing import Optional
+from urllib.parse import urlencode
 
 import pkg_resources
 
@@ -27,8 +28,10 @@ from synapse.api.constants import LoginType, Membership
 from synapse.api.errors import Codes
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
+from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class PasswordResetTestCase(unittest.HomeserverTestCase):
@@ -69,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
 
     def test_basic_password_reset(self):
         """Test basic password reset flow
@@ -250,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Remove the host
         path = link.replace("https://example.com", "")
 
+        # Load the password reset confirmation page
         request, channel = self.make_request("GET", path, shorthand=False)
-        self.render(request)
+        request.render(self.submit_token_resource)
+        self.pump()
+        self.assertEquals(200, channel.code, channel.result)
+
+        # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+        # password reset confirm button
+
+        # Send arguments as url-encoded form data, matching the template's behaviour
+        form_args = []
+        for key, value_list in request.args.items():
+            for value in value_list:
+                arg = (key, value)
+                form_args.append(arg)
+
+        # Confirm the password reset
+        request, channel = self.make_request(
+            "POST",
+            path,
+            content=urlencode(form_args).encode("utf8"),
+            shorthand=False,
+            content_is_form=True,
+        )
+        request.render(self.submit_token_resource)
+        self.pump()
         self.assertEquals(200, channel.code, channel.result)
 
     def _get_link_from_email(self):
@@ -668,16 +696,110 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
-    def _request_token(self, email, client_secret):
+    @override_config({"next_link_domain_whitelist": None})
+    def test_next_link(self):
+        """Tests a valid next_link parameter value with no whitelist (good case)"""
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.com/a/good/site",
+            expect_code=200,
+        )
+
+    @override_config({"next_link_domain_whitelist": None})
+    def test_next_link_exotic_protocol(self):
+        """Tests using a esoteric protocol as a next_link parameter value.
+        Someone may be hosting a client on IPFS etc.
+        """
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
+            expect_code=200,
+        )
+
+    @override_config({"next_link_domain_whitelist": None})
+    def test_next_link_file_uri(self):
+        """Tests next_link parameters cannot be file URI"""
+        # Attempt to use a next_link value that points to the local disk
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="file:///host/path",
+            expect_code=400,
+        )
+
+    @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
+    def test_next_link_domain_whitelist(self):
+        """Tests next_link parameters must fit the whitelist if provided"""
+
+        # Ensure not providing a next_link parameter still works
+        self._request_token(
+            "something@example.com", "some_secret", next_link=None, expect_code=200,
+        )
+
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.com/some/good/page",
+            expect_code=200,
+        )
+
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.org/some/also/good/page",
+            expect_code=200,
+        )
+
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://bad.example.org/some/bad/page",
+            expect_code=400,
+        )
+
+    @override_config({"next_link_domain_whitelist": []})
+    def test_empty_next_link_domain_whitelist(self):
+        """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
+        disallowed
+        """
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.com/a/page",
+            expect_code=400,
+        )
+
+    def _request_token(
+        self,
+        email: str,
+        client_secret: str,
+        next_link: Optional[str] = None,
+        expect_code: int = 200,
+    ) -> str:
+        """Request a validation token to add an email address to a user's account
+
+        Args:
+            email: The email address to validate
+            client_secret: A secret string
+            next_link: A link to redirect the user to after validation
+            expect_code: Expected return code of the call
+
+        Returns:
+            The ID of the new threepid validation session
+        """
+        body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+        if next_link:
+            body["next_link"] = next_link
+
         request, channel = self.make_request(
-            "POST",
-            b"account/3pid/email/requestToken",
-            {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            "POST", b"account/3pid/email/requestToken", body,
         )
         self.render(request)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(expect_code, channel.code, channel.result)
 
-        return channel.json_body["sid"]
+        return channel.json_body.get("sid")
 
     def _request_token_invalid_email(
         self, email, expected_errcode, expected_error, client_secret="foobar",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index f4f3e56777..5f897d49cf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -120,12 +120,13 @@ class _TestImage:
     extension = attr.ib(type=bytes)
     expected_cropped = attr.ib(type=Optional[bytes])
     expected_scaled = attr.ib(type=Optional[bytes])
+    expected_found = attr.ib(default=True, type=bool)
 
 
 @parameterized_class(
     ("test_image",),
     [
-        # smol png
+        # smoll png
         (
             _TestImage(
                 unhexlify(
@@ -161,6 +162,8 @@ class _TestImage:
                 None,
             ),
         ),
+        # an empty file
+        (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
     ],
 )
 class MediaRepoTests(unittest.HomeserverTestCase):
@@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
 
     def test_thumbnail_crop(self):
-        self._test_thumbnail("crop", self.test_image.expected_cropped)
+        self._test_thumbnail(
+            "crop", self.test_image.expected_cropped, self.test_image.expected_found
+        )
 
     def test_thumbnail_scale(self):
-        self._test_thumbnail("scale", self.test_image.expected_scaled)
+        self._test_thumbnail(
+            "scale", self.test_image.expected_scaled, self.test_image.expected_found
+        )
 
-    def _test_thumbnail(self, method, expected_body):
+    def _test_thumbnail(self, method, expected_body, expected_found):
         params = "?width=32&height=32&method=" + method
         request, channel = self.make_request(
             "GET", self.media_id + params, shorthand=False
@@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         self.pump()
 
-        self.assertEqual(channel.code, 200)
-        if expected_body is not None:
+        if expected_found:
+            self.assertEqual(channel.code, 200)
+            if expected_body is not None:
+                self.assertEqual(
+                    channel.result["body"], expected_body, channel.result["body"]
+                )
+            else:
+                # ensure that the result is at least some valid image
+                Image.open(BytesIO(channel.result["body"]))
+        else:
+            # A 404 with a JSON body.
+            self.assertEqual(channel.code, 404)
             self.assertEqual(
-                channel.result["body"], expected_body, channel.result["body"]
+                channel.json_body,
+                {
+                    "errcode": "M_NOT_FOUND",
+                    "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
+                    % method,
+                },
             )
-        else:
-            # ensure that the result is at least some valid image
-            Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index b090bb974c..dcd65c2a50 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -21,7 +21,7 @@ from tests import unittest
 
 class WellKnownTests(unittest.HomeserverTestCase):
     def setUp(self):
-        super(WellKnownTests, self).setUp()
+        super().setUp()
 
         # replace the JsonResource with a WellKnownResource
         self.resource = WellKnownResource(self.hs)