diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 82ac5991e6..a8f6436836 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -13,7 +13,6 @@
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -42,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
def test_version_string(self) -> None:
channel = self.make_request("GET", self.url, shorthand=False)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"server_version", "python_version"}, set(channel.json_body.keys())
)
@@ -79,10 +78,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Should be quarantined
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s"
+ "Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)
@@ -107,7 +106,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Expect a forbidden error
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg="Expected forbidden on quarantining media as a non-admin",
)
@@ -139,7 +138,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Should be successful
- self.assertEqual(HTTPStatus.OK, channel.code)
+ self.assertEqual(200, channel.code)
# Quarantine the media
url = "/_synapse/admin/v1/media/quarantine/%s/%s" % (
@@ -152,7 +151,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Attempt to access the media
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
@@ -209,7 +208,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
@@ -251,7 +250,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
)
@@ -285,7 +284,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
channel = self.make_request("POST", url, access_token=admin_user_tok)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -297,7 +296,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
access_token=admin_user_tok,
)
self.pump(1.0)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item"
)
@@ -318,10 +317,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Shouldn't be quarantined
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s"
+ "Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
@@ -350,7 +349,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
def test_purge_history(self) -> None:
"""
Simple test of purge history API.
- Test only that is is possible to call, get status HTTPStatus.OK and purge_id.
+ Test only that is is possible to call, get status 200 and purge_id.
"""
channel = self.make_request(
@@ -360,7 +359,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("purge_id", channel.json_body)
purge_id = channel.json_body["purge_id"]
@@ -371,5 +370,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 6cf56b1e35..d507a3af8d 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import Collection
from parameterized import parameterized
@@ -51,7 +50,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
self.register_user("user", "pass", admin=False)
@@ -64,7 +63,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -81,7 +80,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# job_name invalid
@@ -92,7 +91,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def _register_bg_update(self) -> None:
@@ -125,7 +124,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, but none should be running.
self.assertDictEqual(
@@ -147,7 +146,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled, and one should be running.
self.assertDictEqual(
@@ -181,7 +180,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/enabled",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
# Disable the BG updates
@@ -191,7 +190,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": False})
# Advance a bit and get the current status, note this will finish the in
@@ -204,7 +203,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(
channel.json_body,
{
@@ -231,7 +230,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# There should be no change from the previous /status response.
self.assertDictEqual(
@@ -259,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
@@ -270,7 +269,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Background updates should be enabled and making progress.
self.assertDictEqual(
@@ -325,7 +324,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# test that each background update is waiting now
for update in updates:
@@ -365,4 +364,4 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index f7080bda87..03f2112b07 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.parse
-from http import HTTPStatus
from parameterized import parameterized
@@ -20,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
+from synapse.handlers.device import DeviceHandler
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -35,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -58,7 +60,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(method, self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -76,7 +78,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -85,7 +87,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:test/devices/%s"
@@ -98,13 +100,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s"
@@ -117,12 +119,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_device(self) -> None:
"""
- Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK.
+ Tests that a lookup for a device that does not exist returns either 404 or 200.
"""
url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote(
self.other_user
@@ -134,7 +136,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
channel = self.make_request(
@@ -143,7 +145,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
channel = self.make_request(
"DELETE",
@@ -151,8 +153,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # Delete unknown device returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown device returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_update_device_too_long_display_name(self) -> None:
"""
@@ -179,7 +181,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content=update,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
@@ -189,12 +191,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_no_display_name(self) -> None:
"""
- Tests that a update for a device without JSON returns a HTTPStatus.OK
+ Tests that a update for a device without JSON returns a 200
"""
# Set iniital display name.
update = {"display_name": "new display"}
@@ -210,7 +212,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure the display name was not updated.
channel = self.make_request(
@@ -219,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"])
def test_update_display_name(self) -> None:
@@ -234,7 +236,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
content={"display_name": "new displayname"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check new display_name
channel = self.make_request(
@@ -243,7 +245,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"])
def test_get_device(self) -> None:
@@ -256,7 +258,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
# Check that all fields are available
self.assertIn("user_id", channel.json_body)
@@ -281,7 +283,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Ensure that the number of devices is decreased
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
@@ -312,7 +314,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -331,7 +333,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -339,7 +341,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
channel = self.make_request(
@@ -348,12 +350,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
@@ -363,7 +365,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_user_has_no_devices(self) -> None:
@@ -379,7 +381,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"]))
@@ -399,7 +401,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
@@ -438,7 +440,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -457,7 +459,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -465,7 +467,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
channel = self.make_request(
@@ -474,12 +476,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
@@ -489,12 +491,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_devices(self) -> None:
"""
- Tests that a remove of a device that does not exist returns HTTPStatus.OK.
+ Tests that a remove of a device that does not exist returns 200.
"""
channel = self.make_request(
"POST",
@@ -503,8 +505,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": ["unknown_device1", "unknown_device2"]},
)
- # Delete unknown devices returns status HTTPStatus.OK
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ # Delete unknown devices returns status 200
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_delete_devices(self) -> None:
"""
@@ -533,7 +535,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
content={"devices": device_ids},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
res = self.get_success(self.handler.get_devices_by_user(self.other_user))
self.assertEqual(0, len(res))
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 4f89f8b534..8a4e5c3f77 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -81,16 +80,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -99,11 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -117,7 +108,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -134,7 +125,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -151,7 +142,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -168,7 +159,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
@@ -185,7 +176,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -205,7 +196,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["event_reports"]), 10)
self.assertNotIn("next_token", channel.json_body)
@@ -225,7 +216,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["event_reports"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -247,7 +238,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -265,7 +256,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
report = 1
@@ -278,7 +269,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
def test_invalid_search_order(self) -> None:
"""
- Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST
+ Testing that a invalid search order returns a 400
"""
channel = self.make_request(
@@ -287,17 +278,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
def test_limit_is_negative(self) -> None:
"""
- Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative limit parameter returns a 400
"""
channel = self.make_request(
@@ -306,16 +293,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_from_is_negative(self) -> None:
"""
- Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST
+ Testing that a negative from parameter returns a 400
"""
channel = self.make_request(
@@ -324,11 +307,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -344,7 +323,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -357,7 +336,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 20)
self.assertNotIn("next_token", channel.json_body)
@@ -370,7 +349,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -384,7 +363,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -400,7 +379,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _create_event_and_report_without_parameters(
self, room_id: str, user_tok: str
@@ -415,7 +394,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
{},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in an event report"""
@@ -431,6 +410,33 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertIn("score", c)
self.assertIn("reason", c)
+ def test_count_correct_despite_table_deletions(self) -> None:
+ """
+ Tests that the count matches the number of rows, even if rows in joined tables
+ are missing.
+ """
+
+ # Delete rows from room_stats_state for one of our rooms.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_delete(
+ "room_stats_state", {"room_id": self.room_id1}, desc="_"
+ )
+ )
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ # The 'total' field is 10 because only 10 reports will actually
+ # be retrievable since we deleted the rows in the room_stats_state
+ # table.
+ self.assertEqual(channel.json_body["total"], 10)
+ # This is consistent with the number of rows actually returned.
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+
class EventReportDetailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -466,16 +472,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -484,11 +486,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self) -> None:
@@ -502,12 +500,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_invalid_report_id(self) -> None:
"""
- Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST.
+ Testing that an invalid `report_id` returns a 400.
"""
# `report_id` is negative
@@ -517,11 +515,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -535,11 +529,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -553,11 +543,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"The report_id parameter must be a string representing a positive integer.",
@@ -566,7 +552,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
def test_report_id_not_found(self) -> None:
"""
- Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND.
+ Testing that a not existing `report_id` returns a 404.
"""
channel = self.make_request(
@@ -575,11 +561,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
self.assertEqual("Event report not found", channel.json_body["error"])
@@ -594,7 +576,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
{"score": -100, "reason": "this makes me sad"},
access_token=user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def _check_fields(self, content: JsonDict) -> None:
"""Checks that all attributes are present in a event report"""
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 929bbdc37d..4c7864c629 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from parameterized import parameterized
@@ -64,7 +63,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -77,7 +76,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -87,7 +86,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -97,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -107,7 +106,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -117,7 +116,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
# invalid destination
@@ -127,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -142,7 +141,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -160,7 +159,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -178,7 +177,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["destinations"]), 10)
@@ -198,7 +197,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body)
@@ -211,7 +210,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), number_destinations)
self.assertNotIn("next_token", channel.json_body)
@@ -224,7 +223,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -238,7 +237,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_destinations)
self.assertEqual(len(channel.json_body["destinations"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -255,7 +254,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_destinations, len(channel.json_body["destinations"]))
self.assertEqual(number_destinations, channel.json_body["total"])
@@ -290,7 +289,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_destination_list))
returned_order = [
@@ -376,7 +375,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that destinations were returned
self.assertTrue("destinations" in channel.json_body)
@@ -418,7 +417,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"])
# Check that all fields are available
@@ -435,7 +434,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("sub0.example.com", channel.json_body["destination"])
self.assertEqual(0, channel.json_body["retry_last_ts"])
self.assertEqual(0, channel.json_body["retry_interval"])
@@ -452,7 +451,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
retry_timings = self.get_success(
self.store.get_destination_retry_timings("sub0.example.com")
@@ -469,7 +468,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"The retry timing does not need to be reset for this destination.",
channel.json_body["error"],
@@ -561,7 +560,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -574,7 +573,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -584,7 +583,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -594,7 +593,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid destination
@@ -604,7 +603,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -619,7 +618,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 3)
self.assertEqual(channel.json_body["next_token"], "3")
@@ -637,7 +636,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 5)
self.assertNotIn("next_token", channel.json_body)
@@ -655,7 +654,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(channel.json_body["next_token"], "8")
self.assertEqual(len(channel.json_body["rooms"]), 5)
@@ -673,7 +672,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body)
+ self.assertEqual(200, channel_asc.code, msg=channel_asc.json_body)
self.assertEqual(channel_asc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"]))
self._check_fields(channel_asc.json_body["rooms"])
@@ -685,7 +684,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body)
+ self.assertEqual(200, channel_desc.code, msg=channel_desc.json_body)
self.assertEqual(channel_desc.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"]))
self._check_fields(channel_desc.json_body["rooms"])
@@ -711,7 +710,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body)
@@ -724,7 +723,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), number_rooms)
self.assertNotIn("next_token", channel.json_body)
@@ -737,7 +736,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 4)
self.assertEqual(channel.json_body["next_token"], "4")
@@ -751,7 +750,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(len(channel.json_body["rooms"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -767,7 +766,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_rooms)
self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
self._check_fields(channel.json_body["rooms"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index e909e444ac..aadb31ca83 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from http import HTTPStatus
from parameterized import parameterized
@@ -60,7 +59,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("DELETE", url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -81,16 +80,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_does_not_exist(self) -> None:
"""
- Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a media that does not exist returns a 404
"""
url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
@@ -100,12 +95,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a media that is not a local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
@@ -115,7 +110,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_delete_media(self) -> None:
@@ -131,7 +126,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -151,11 +146,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Should be successful
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
- % server_and_media_id
+ "Expected to receive a 200 on accessing media: %s" % server_and_media_id
),
)
@@ -172,7 +166,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -189,10 +183,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% server_and_media_id
),
)
@@ -231,11 +225,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -251,16 +241,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_is_not_local(self) -> None:
"""
- Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for media that is not local returns a 400
"""
url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
@@ -270,7 +256,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_missing_parameter(self) -> None:
@@ -283,11 +269,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"]
@@ -303,11 +285,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -320,11 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
@@ -338,11 +312,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.",
@@ -355,11 +325,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
@@ -388,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
media_id,
@@ -413,7 +379,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -425,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -449,7 +415,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -460,7 +426,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -485,7 +451,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"avatar_url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -493,7 +459,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -504,7 +470,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -530,7 +496,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
content={"url": "mxc://%s" % (server_and_media_id,)},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
now_ms = self.clock.time_msec()
channel = self.make_request(
@@ -538,7 +504,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self._access_media(server_and_media_id)
@@ -549,7 +515,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual(
server_and_media_id.split("/")[1],
@@ -569,7 +535,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -602,10 +568,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
if expect_success:
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.OK on accessing media: %s"
+ "Expected to receive a 200 on accessing media: %s"
% server_and_media_id
),
)
@@ -613,10 +579,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertTrue(os.path.exists(local_path))
else:
self.assertEqual(
- HTTPStatus.NOT_FOUND,
+ 404,
channel.code,
msg=(
- "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s"
+ "Expected to receive a 404 on accessing deleted media: %s"
% (server_and_media_id)
),
)
@@ -648,7 +614,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -668,11 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
b"{}",
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"])
@@ -689,11 +651,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_quarantine_media(self) -> None:
@@ -712,7 +670,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -726,7 +684,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -753,7 +711,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# verify that is not in quarantine
@@ -785,7 +743,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
upload_resource,
SMALL_PNG,
tok=self.admin_user_tok,
- expect_code=HTTPStatus.OK,
+ expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -801,11 +759,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url % (action, self.media_id), b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["protect", "unprotect"])
@@ -822,11 +776,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_protect_media(self) -> None:
@@ -845,7 +795,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -859,7 +809,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
@@ -895,7 +845,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -914,11 +864,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -931,11 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts must be a positive integer.",
@@ -948,11 +890,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual(
"Query parameter before_ts you provided is from the year 1970. "
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 8354250ec2..8f8abc21c7 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -13,7 +13,6 @@
# limitations under the License.
import random
import string
-from http import HTTPStatus
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -74,11 +73,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_create_no_auth(self) -> None:
"""Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_create_requester_not_admin(self) -> None:
@@ -89,11 +84,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_create_using_defaults(self) -> None:
@@ -105,7 +96,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -129,7 +120,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
@@ -150,7 +141,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 16)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -168,11 +159,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_invalid_chars(self) -> None:
@@ -188,11 +175,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_already_exists(self) -> None:
@@ -207,7 +190,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body)
+ self.assertEqual(200, channel1.code, msg=channel1.json_body)
channel2 = self.make_request(
"POST",
@@ -215,7 +198,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
data,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
+ self.assertEqual(400, channel2.code, msg=channel2.json_body)
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_unable_to_generate_token(self) -> None:
@@ -251,7 +234,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
# Should fail with negative integer
@@ -262,7 +245,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.BAD_REQUEST,
+ 400,
channel.code,
msg=channel.json_body,
)
@@ -275,11 +258,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_expiry_time(self) -> None:
@@ -291,11 +270,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() - 10000},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with float
@@ -305,11 +280,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": self.clock.time_msec() + 1000000.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_length(self) -> None:
@@ -321,7 +292,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 64},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["token"]), 64)
# Should fail with 0
@@ -331,11 +302,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -345,11 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a float
@@ -359,11 +322,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 8.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with 65
@@ -373,11 +332,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"length": 65},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# UPDATING
@@ -389,11 +344,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_update_requester_not_admin(self) -> None:
@@ -404,11 +355,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_update_non_existent(self) -> None:
@@ -420,11 +367,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_update_uses_allowed(self) -> None:
@@ -439,7 +382,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -450,7 +393,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 0},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 0)
self.assertIsNone(channel.json_body["expiry_time"])
@@ -461,7 +404,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -472,11 +415,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": 1.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail with a negative integer
@@ -486,11 +425,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"uses_allowed": -5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_expiry_time(self) -> None:
@@ -506,7 +441,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -517,7 +452,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": None},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIsNone(channel.json_body["expiry_time"])
self.assertIsNone(channel.json_body["uses_allowed"])
@@ -529,11 +464,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": past_time},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# Should fail a float
@@ -543,11 +474,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{"expiry_time": new_expiry_time + 0.5},
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_both(self) -> None:
@@ -568,7 +495,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
@@ -589,11 +516,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
# DELETING
@@ -605,11 +528,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_delete_requester_not_admin(self) -> None:
@@ -620,11 +539,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_delete_non_existent(self) -> None:
@@ -636,11 +551,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_delete(self) -> None:
@@ -655,7 +566,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# GETTING ONE
@@ -666,11 +577,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url + "/1234", # Token doesn't exist but that doesn't matter
{},
)
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_get_requester_not_admin(self) -> None:
@@ -682,7 +589,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -697,11 +604,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.NOT_FOUND,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get(self) -> None:
@@ -716,7 +619,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["token"], token)
self.assertIsNone(channel.json_body["uses_allowed"])
self.assertIsNone(channel.json_body["expiry_time"])
@@ -728,11 +631,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_list_no_auth(self) -> None:
"""Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_list_requester_not_admin(self) -> None:
@@ -743,11 +642,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
{},
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_list_all(self) -> None:
@@ -762,7 +657,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
token_info = channel.json_body["registration_tokens"][0]
self.assertEqual(token_info["token"], token)
@@ -780,11 +675,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def _test_list_query_parameter(self, valid: str) -> None:
"""Helper used to test both valid=true and valid=false."""
@@ -816,7 +707,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
token_info_1 = channel.json_body["registration_tokens"][0]
token_info_2 = channel.json_body["registration_tokens"][1]
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 989cbdb5e2..e0f5d54aba 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+import time
import urllib.parse
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock
@@ -23,10 +24,11 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import Codes
-from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.pagination import PaginationHandler, PurgeStatus
from synapse.rest.client import directory, events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
+from synapse.util.stringutils import random_string
from tests import unittest
@@ -68,7 +70,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -78,7 +80,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -94,11 +96,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/rooms/%s" % "invalidroom"
@@ -109,7 +111,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -127,7 +129,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -145,7 +147,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -163,7 +165,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -178,7 +180,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self) -> None:
@@ -202,7 +204,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -233,7 +235,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -265,7 +267,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -296,7 +298,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
# The room is now blocked.
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_blocked(room_id)
def test_shutdown_room_consent(self) -> None:
@@ -319,7 +321,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -337,7 +339,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -366,7 +368,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
{"history_visibility": "world_readable"},
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -383,7 +385,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -398,7 +400,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -494,7 +496,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_requester_is_no_admin(self, method: str, url: str) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -504,7 +506,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self) -> None:
@@ -522,7 +524,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -533,7 +535,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"])
@@ -546,7 +548,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
)
def test_room_is_not_valid(self, method: str, url: str) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
channel = self.make_request(
@@ -556,7 +558,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -574,7 +576,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -592,7 +594,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -610,7 +612,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self) -> None:
@@ -625,7 +627,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_delete_expired_status(self) -> None:
@@ -639,7 +641,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id1 = channel.json_body["delete_id"]
@@ -654,7 +656,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id2 = channel.json_body["delete_id"]
@@ -665,7 +667,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual("complete", channel.json_body["results"][1]["status"])
@@ -682,7 +684,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
@@ -696,7 +698,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_delete_same_room_twice(self) -> None:
@@ -722,9 +724,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
- )
+ self.assertEqual(400, second_channel.code, msg=second_channel.json_body)
self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
self.assertEqual(
f"History purge already in progress for {self.room_id}",
@@ -733,7 +733,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
# get result of first call
first_channel.await_result()
- self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+ self.assertEqual(200, first_channel.code, msg=first_channel.json_body)
self.assertIn("delete_id", first_channel.json_body)
# check status after finish the task
@@ -764,7 +764,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -795,7 +795,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -827,7 +827,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -858,7 +858,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.room_id,
body="foo",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Test that room is not purged
@@ -876,7 +876,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -887,7 +887,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room
@@ -914,7 +914,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
content={"history_visibility": "world_readable"},
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -931,7 +931,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("delete_id", channel.json_body)
delete_id = channel.json_body["delete_id"]
@@ -942,7 +942,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, len(channel.json_body["results"]))
# Test that member has moved to new room
@@ -955,7 +955,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self._has_no_members(self.room_id)
# Assert we can no longer peek into the room
- self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
+ self._assert_peek(self.room_id, expect_code=403)
def _is_blocked(self, room_id: str, expect: bool = True) -> None:
"""Assert that the room is blocked or not"""
@@ -1026,9 +1026,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.url_status_by_room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
- )
+ self.assertEqual(200, channel_room_id.code, msg=channel_room_id.json_body)
self.assertEqual(1, len(channel_room_id.json_body["results"]))
self.assertEqual(
delete_id, channel_room_id.json_body["results"][0]["delete_id"]
@@ -1041,7 +1039,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel_delete_id.code,
msg=channel_delete_id.json_body,
)
@@ -1085,7 +1083,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
room_ids = []
for _ in range(total_rooms):
room_id = self.helper.create_room_as(
- self.admin_user, tok=self.admin_user_tok
+ self.admin_user,
+ tok=self.admin_user_tok,
+ is_public=True,
)
room_ids.append(room_id)
@@ -1100,7 +1100,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -1124,8 +1124,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("version", r)
self.assertIn("creator", r)
self.assertIn("encryption", r)
- self.assertIn("federatable", r)
- self.assertIn("public", r)
+ self.assertIs(r["federatable"], True)
+ self.assertIs(r["public"], True)
self.assertIn("join_rules", r)
self.assertIn("guest_access", r)
self.assertIn("history_visibility", r)
@@ -1186,7 +1186,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -1226,7 +1226,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self) -> None:
"""Test the correct attributes for a room are returned"""
@@ -1253,7 +1253,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1285,7 +1285,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1341,7 +1341,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1487,7 +1487,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
def _search_test(
expected_room_id: Optional[str],
search_term: str,
- expected_http_code: int = HTTPStatus.OK,
+ expected_http_code: int = 200,
) -> None:
"""Search for a room and check that the returned room's id is a match
@@ -1505,7 +1505,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that rooms were returned
@@ -1548,7 +1548,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo")
_search_test(None, "bar")
- _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST)
+ _search_test(None, "", expected_http_code=400)
# Test that the whole room id returns the room
_search_test(room_id_1, room_id_1)
@@ -1585,15 +1585,19 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
# Create two test rooms
- room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
- room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_id_1 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=True
+ )
+ room_id_2 = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok, is_public=False
+ )
room_name_1 = "something"
room_name_2 = "else"
@@ -1618,7 +1622,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("room_id", channel.json_body)
self.assertIn("name", channel.json_body)
@@ -1638,7 +1642,11 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("history_visibility", channel.json_body)
self.assertIn("state_events", channel.json_body)
self.assertIn("room_type", channel.json_body)
+ self.assertIn("forgotten", channel.json_body)
+
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ self.assertIs(True, channel.json_body["federatable"])
+ self.assertIs(True, channel.json_body["public"])
def test_single_room_devices(self) -> None:
"""Test that `joined_local_devices` can be requested correctly"""
@@ -1650,7 +1658,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"])
# Have another user join the room
@@ -1664,7 +1672,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"])
# leave room
@@ -1676,7 +1684,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"])
def test_room_members(self) -> None:
@@ -1707,7 +1715,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
@@ -1720,7 +1728,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertCountEqual(
["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
@@ -1738,7 +1746,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
# testing that the state events match is painful and not done here. We assume that
# the create_room already does the right thing, so no need to verify that we got
@@ -1755,7 +1763,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1784,10 +1792,203 @@ class RoomTestCase(unittest.HomeserverTestCase):
# delete the rooms and get joined roomed membership
url = f"/_matrix/client/r0/rooms/{room_id}/joined_members"
channel = self.make_request("GET", url.encode("ascii"), access_token=user_tok)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+class RoomMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.user = self.register_user("foo", "pass")
+ self.user_tok = self.login("foo", "pass")
+ self.room_id = self.helper.create_room_as(self.user, tok=self.user_tok)
+
+ def test_timestamp_to_event(self) -> None:
+ """Test that providing the current timestamp can get the last event."""
+ self.helper.send(self.room_id, body="message 1", tok=self.user_tok)
+ second_event_id = self.helper.send(
+ self.room_id, body="message 2", tok=self.user_tok
+ )["event_id"]
+ ts = str(round(time.time() * 1000))
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/timestamp_to_event?dir=b&ts=%s"
+ % (self.room_id, ts),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("event_id", channel.json_body)
+ self.assertEqual(second_event_id, channel.json_body["event_id"])
+
+ def test_topo_token_is_accepted(self) -> None:
+ """Test Topo Token is accepted."""
+ token = "t1-0_0_0_0_0_0_0_0_0"
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("start", channel.json_body)
+ self.assertEqual(token, channel.json_body["start"])
+ self.assertIn("chunk", channel.json_body)
+ self.assertIn("end", channel.json_body)
+
+ def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
+ """Test that stream token is accepted for forward pagination."""
+ token = "s0_0_0_0_0_0_0_0_0"
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertIn("start", channel.json_body)
+ self.assertEqual(token, channel.json_body["start"])
+ self.assertIn("chunk", channel.json_body)
+ self.assertIn("end", channel.json_body)
+
+ def test_room_messages_backward(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ latest_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?dir=b" % (self.room_id,),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+ # in backwards, this is the first event
+ self.assertEqual(chunk[0]["event_id"], latest_event_id)
+
+ def test_room_messages_forward(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ latest_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?dir=f" % (self.room_id,),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 6, [event["content"] for event in chunk])
+
+ # in forward, this is the last event
+ self.assertEqual(chunk[5]["event_id"], latest_event_id)
+
+ def test_room_messages_purge(self) -> None:
+ """Test room messages can be retrieved by an admin that isn't in the room."""
+ store = self.hs.get_datastores().main
+ pagination_handler = self.hs.get_pagination_handler()
+
+ # Send a first message in the room, which will be removed by the purge.
+ first_event_id = self.helper.send(
+ self.room_id, body="message 1", tok=self.user_tok
+ )["event_id"]
+ first_token = self.get_success(
+ store.get_topological_token_for_event(first_event_id)
+ )
+ first_token_str = self.get_success(first_token.to_string(store))
+
+ # Send a second message in the room, which won't be removed, and which we'll
+ # use as the marker to purge events before.
+ second_event_id = self.helper.send(
+ self.room_id, body="message 2", tok=self.user_tok
+ )["event_id"]
+ second_token = self.get_success(
+ store.get_topological_token_for_event(second_event_id)
+ )
+ second_token_str = self.get_success(second_token.to_string(store))
+
+ # Send a third event in the room to ensure we don't fall under any edge case
+ # due to our marker being the latest forward extremity in the room.
+ self.helper.send(self.room_id, body="message 3", tok=self.user_tok)
+
+ # Check that we get the first and second message when querying /messages.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+
+ # Purge every event before the second event.
+ purge_id = random_string(16)
+ pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+ self.get_success(
+ pagination_handler._purge_history(
+ purge_id=purge_id,
+ room_id=self.room_id,
+ token=second_token_str,
+ delete_local_events=True,
+ )
+ )
+
+ # Check that we only get the second message through /message now that the first
+ # has been purged.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
+
+ # Check that we get no event, but also no error, when querying /messages with
+ # the token that was pointing at the first event, because we don't have it
+ # anymore.
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/messages?from=%s&dir=b&filter=%s"
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ chunk = channel.json_body["chunk"]
+ self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
+
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -1813,7 +2014,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
@@ -1823,7 +2024,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1838,12 +2039,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1853,7 +2054,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self) -> None:
@@ -1868,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1876,7 +2077,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_does_not_exist(self) -> None:
"""
- Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
+ Check that unknown rooms/server return error 404.
"""
url = "/_synapse/admin/v1/join/!unknown:test"
@@ -1887,7 +2088,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(
"Can't join remote room because no servers that are in the room have been provided.",
channel.json_body["error"],
@@ -1895,7 +2096,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
def test_room_is_not_valid(self) -> None:
"""
- Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
+ Check that invalid room names, return an error 400.
"""
url = "/_synapse/admin/v1/join/invalidroom"
@@ -1906,7 +2107,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1924,7 +2125,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1934,7 +2135,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self) -> None:
@@ -1954,7 +2155,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self) -> None:
@@ -1982,7 +2183,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
@@ -1995,7 +2196,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
content={"user_id": self.second_user_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -2005,7 +2206,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self) -> None:
@@ -2025,7 +2226,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -2035,7 +2236,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self) -> None:
@@ -2069,7 +2270,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self) -> None:
@@ -2099,7 +2300,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -2158,7 +2359,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -2185,7 +2386,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -2211,7 +2412,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -2245,11 +2446,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins.
+ # We expect this to fail with a 400 as there are no room admins.
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
@@ -2279,7 +2480,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
@parameterized.expand([("PUT",), ("GET",)])
def test_requester_is_no_admin(self, method: str) -> None:
- """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
+ """If the user is not a server admin, an error 403 is returned."""
channel = self.make_request(
method,
@@ -2288,12 +2489,12 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand([("PUT",), ("GET",)])
def test_room_is_not_valid(self, method: str) -> None:
- """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
+ """Check that invalid room names, return an error 400."""
channel = self.make_request(
method,
@@ -2302,7 +2503,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -2319,7 +2520,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# `block` is not set
@@ -2330,7 +2531,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no content is send
@@ -2340,7 +2541,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
def test_block_room(self) -> None:
@@ -2354,7 +2555,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self._is_blocked(room_id, expect=True)
@@ -2378,7 +2579,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self._is_blocked(self.room_id, expect=True)
@@ -2394,7 +2595,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self._is_blocked(room_id, expect=False)
@@ -2418,7 +2619,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
content={"block": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self._is_blocked(self.room_id, expect=False)
@@ -2433,7 +2634,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["block"])
self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -2457,7 +2658,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url % room_id,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body["block"])
self.assertNotIn("user_id", channel.json_body)
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index dbcba2663c..a2f347f666 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
@@ -57,7 +56,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url)
self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
+ 401,
channel.code,
msg=channel.json_body,
)
@@ -72,7 +71,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.FORBIDDEN,
+ 403,
channel.code,
msg=channel.json_body,
)
@@ -80,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_does_not_exist(self) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
channel = self.make_request(
"POST",
self.url,
@@ -88,13 +87,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": "@unknown_person:test", "content": ""},
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
channel = self.make_request(
"POST",
@@ -106,7 +105,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Server notices can only be sent to local users", channel.json_body["error"]
)
@@ -122,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
# no content
@@ -133,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# no body
@@ -144,7 +143,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": ""},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'body' not in content", channel.json_body["error"])
@@ -156,10 +155,66 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
content={"user_id": self.other_user, "content": {"body": ""}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_avatar_url": "somthingwrong",
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_invalid_avatar_url(self) -> None:
+ """If avatar url in homeserver.yaml is invalid and
+ "check avatar size and mime type" is set, an error is returned.
+ TODO: Should be checked when reading the configuration."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+
+ self.assertEqual(500, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ @override_config(
+ {
+ "server_notices": {
+ "system_mxid_localpart": "notices",
+ "system_mxid_display_name": "test display name",
+ "system_mxid_avatar_url": None,
+ },
+ "max_avatar_size": "10M",
+ }
+ )
+ def test_displayname_is_set_avatar_is_none(self) -> None:
+ """
+ Tests that sending a server notices is successfully,
+ if a display_name is set, avatar_url is `None` and
+ "check avatar size and mime type" is set.
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ self._check_invite_and_join_status(self.other_user, 1, 0)
+
def test_server_notice_disabled(self) -> None:
"""Tests that server returns error if server notice is disabled"""
channel = self.make_request(
@@ -172,7 +227,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual(
"Server notices are not enabled on this server", channel.json_body["error"]
@@ -197,7 +252,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -226,7 +281,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has no new invites or memberships
self._check_invite_and_join_status(self.other_user, 0, 1)
@@ -260,7 +315,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -301,7 +356,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -341,7 +396,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg one"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -388,7 +443,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
"content": {"msgtype": "m.text", "body": "test msg two"},
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# user has one invite
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
@@ -538,7 +593,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/_matrix/client/r0/sync", access_token=token
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
# Get the messages
room = channel.json_body["rooms"]["join"][room_id]
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 7cb8ec57ba..b60f16b914 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -51,16 +50,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(
- HTTPStatus.UNAUTHORIZED,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
"""
- If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
+ If the user is not a server admin, an error 403 is returned.
"""
channel = self.make_request(
"GET",
@@ -69,11 +64,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_tok,
)
- self.assertEqual(
- HTTPStatus.FORBIDDEN,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -87,11 +78,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -101,11 +88,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -115,11 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from_ts
@@ -129,11 +108,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative until_ts
@@ -143,11 +118,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# until_ts smaller from_ts
@@ -157,11 +128,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# empty search term
@@ -171,11 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -185,11 +148,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -204,7 +163,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -222,7 +181,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -240,7 +199,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -262,7 +221,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -275,7 +234,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -288,7 +247,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -301,7 +260,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -318,7 +277,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["users"]))
@@ -415,7 +374,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media starting at `ts1` after creating first media
@@ -425,7 +384,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s" % (ts1,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
self._create_media(self.other_user_tok, 3)
@@ -440,7 +399,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
# filter media until `ts2` and earlier
@@ -449,7 +408,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?until_ts=%s" % (ts2,),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
def test_search_term(self) -> None:
@@ -461,7 +420,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 20)
# filter user 1 and 10-19 by `user_id`
@@ -470,7 +429,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foo_user_1",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 11)
# filter on this user in `displayname`
@@ -479,7 +438,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=bar_user_10",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10")
self.assertEqual(channel.json_body["total"], 1)
@@ -489,7 +448,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url + "?search_term=foobar",
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0)
def _create_users_with_media(self, number_users: int, media_per_user: int) -> None:
@@ -515,7 +474,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
for _ in range(number_media):
# Upload some media into the room
self.helper.upload_media(
- upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK
+ upload_resource, SMALL_PNG, tok=user_token, expect_code=200
)
def _check_fields(self, content: List[JsonDict]) -> None:
@@ -549,7 +508,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["user_id"] for row in channel.json_body["users"]]
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 12db68d564..e8c9457794 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@ import hmac
import os
import urllib.parse
from binascii import unhexlify
-from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock, patch
@@ -26,13 +25,13 @@ from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
-from synapse.rest.client import devices, login, logout, profile, room, sync
+from synapse.rest.client import devices, login, logout, profile, register, room, sync
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -42,14 +41,12 @@ from tests.unittest import override_config
class UserRegisterTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
profile.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.url = "/_synapse/admin/v1/register"
self.registration_handler = Mock()
@@ -79,7 +76,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"Shared secret registration is not enabled", channel.json_body["error"]
)
@@ -111,7 +108,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# 61 seconds
@@ -119,7 +116,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_register_incorrect_nonce(self) -> None:
@@ -142,7 +139,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self) -> None:
@@ -169,7 +166,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
def test_nonce_reuse(self) -> None:
@@ -192,13 +189,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
def test_missing_parts(self) -> None:
@@ -219,7 +216,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be an empty body present
channel = self.make_request("POST", self.url, {})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
#
@@ -229,28 +226,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Must be present
channel = self.make_request("POST", self.url, {"nonce": nonce()})
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
#
@@ -261,28 +258,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
body = {"nonce": nonce(), "username": "a"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
body = {"nonce": nonce(), "username": "a", "password": 1234}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
#
@@ -298,7 +295,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self) -> None:
@@ -323,11 +320,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob1:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob1", channel.json_body["displayname"])
# displayname is None
@@ -347,11 +344,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob2:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("bob2", channel.json_body["displayname"])
# displayname is empty
@@ -371,11 +368,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -394,11 +391,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob4:test/displayname")
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("Bob's Name", channel.json_body["displayname"])
@override_config(
@@ -442,12 +439,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request("POST", self.url, body)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
class UsersListTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -466,7 +462,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -478,7 +474,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_all_users(self) -> None:
@@ -494,7 +490,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])
@@ -508,7 +504,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
expected_user_id: Optional[str],
search_term: str,
search_field: Optional[str] = "name",
- expected_http_code: Optional[int] = HTTPStatus.OK,
+ expected_http_code: Optional[int] = 200,
) -> None:
"""Search for a user and check that the returned user's id is a match
@@ -530,7 +526,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
- if expected_http_code != HTTPStatus.OK:
+ if expected_http_code != 200:
return
# Check that users were returned
@@ -579,6 +575,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
def test_invalid_parameter(self) -> None:
"""
If parameters are invalid, an error is returned.
@@ -591,7 +597,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -601,7 +607,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
@@ -611,7 +617,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid deactivated
@@ -621,7 +627,17 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # invalid approved
+ channel = self.make_request(
+ "GET",
+ self.url + "?approved=not_bool",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by
@@ -631,7 +647,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -641,7 +657,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self) -> None:
@@ -659,7 +675,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
@@ -680,7 +696,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -701,7 +717,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
@@ -724,7 +740,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -737,7 +753,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
@@ -750,7 +766,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
@@ -764,7 +780,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -842,6 +858,99 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_filter_out_approved(self) -> None:
+ """Tests that the endpoint can filter out approved users."""
+ # Create our users.
+ self._create_users(2)
+
+ # Get the list of users.
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+
+ # Exclude the admin, because we don't want to accidentally un-approve the admin.
+ non_admin_user_ids = [
+ user["name"]
+ for user in channel.json_body["users"]
+ if user["name"] != self.admin_user
+ ]
+
+ self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids)
+
+ # Select a user and un-approve them. We do this rather than the other way around
+ # because, since these users are created by an admin, we consider them already
+ # approved.
+ not_approved_user = non_admin_user_ids[0]
+
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{not_approved_user}",
+ {"approved": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+
+ # Now get the list of users again, this time filtering out approved users.
+ channel = self.make_request(
+ "GET",
+ self.url + "?approved=false",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, channel.result)
+
+ non_admin_user_ids = [
+ user["name"]
+ for user in channel.json_body["users"]
+ if user["name"] != self.admin_user
+ ]
+
+ # We should only have our unapproved user now.
+ self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
+ self.assertEqual(not_approved_user, non_admin_user_ids[0])
+
+ def test_erasure_status(self) -> None:
+ # Create a new user.
+ user_id = self.register_user("eraseme", "eraseme")
+
+ # They should appear in the list users API, marked as not erased.
+ channel = self.make_request(
+ "GET",
+ self.url + "?deactivated=true",
+ access_token=self.admin_user_tok,
+ )
+ users = {user["name"]: user for user in channel.json_body["users"]}
+ self.assertIs(users[user_id]["erased"], False)
+
+ # Deactivate that user, requesting erasure.
+ deactivate_account_handler = self.hs.get_deactivate_account_handler()
+ self.get_success(
+ deactivate_account_handler.deactivate_account(
+ user_id, erase_data=True, requester=create_requester(user_id)
+ )
+ )
+
+ # Repeat the list users query. They should now be marked as erased.
+ channel = self.make_request(
+ "GET",
+ self.url + "?deactivated=true",
+ access_token=self.admin_user_tok,
+ )
+ users = {user["name"]: user for user in channel.json_body["users"]}
+ self.assertIs(users[user_id]["erased"], True)
+
def _order_test(
self,
expected_user_list: List[str],
@@ -867,7 +976,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_user_list))
returned_order = [row["name"] for row in channel.json_body["users"]]
@@ -905,11 +1014,100 @@ class UsersListTestCase(unittest.HomeserverTestCase):
)
-class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+class UserDevicesTestCase(unittest.HomeserverTestCase):
+ """
+ Tests user device management-related Admin APIs.
+ """
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Set up an Admin user to query the Admin API with.
+ self.admin_user_id = self.register_user("admin", "pass", admin=True)
+ self.admin_user_token = self.login("admin", "pass")
+
+ # Set up a test user to query the devices of.
+ self.other_user_device_id = "TESTDEVICEID"
+ self.other_user_device_display_name = "My Test Device"
+ self.other_user_client_ip = "1.2.3.4"
+ self.other_user_user_agent = "EquestriaTechnology/123.0"
+
+ self.other_user_id = self.register_user("user", "pass", displayname="User1")
+ self.other_user_token = self.login(
+ "user",
+ "pass",
+ device_id=self.other_user_device_id,
+ additional_request_fields={
+ "initial_device_display_name": self.other_user_device_display_name,
+ },
+ )
+
+ # Have the "other user" make a request so that the "last_seen_*" fields are
+ # populated in the tests below.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/sync",
+ access_token=self.other_user_token,
+ client_ip=self.other_user_client_ip,
+ custom_headers=[
+ ("User-Agent", self.other_user_user_agent),
+ ],
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ def test_list_user_devices(self) -> None:
+ """
+ Tests that a user's devices and attributes are listed correctly via the Admin API.
+ """
+ # Request all devices of "other user"
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Double-check we got the single device expected
+ user_devices = channel.json_body["devices"]
+ self.assertEqual(len(user_devices), 1)
+ self.assertEqual(channel.json_body["total"], 1)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(user_devices[0])
+
+ # Request just a single device for "other user" by its ID
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v2/users/{self.other_user_id}/devices/"
+ f"{self.other_user_device_id}",
+ access_token=self.admin_user_token,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Check that all the attributes of the device reported are as expected.
+ self._validate_attributes_of_device_response(channel.json_body)
+
+ def _validate_attributes_of_device_response(self, response: JsonDict) -> None:
+ # Check that all device expected attributes are present
+ self.assertEqual(response["user_id"], self.other_user_id)
+ self.assertEqual(response["device_id"], self.other_user_device_id)
+ self.assertEqual(response["display_name"], self.other_user_device_display_name)
+ self.assertEqual(response["last_seen_ip"], self.other_user_client_ip)
+ self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent)
+ self.assertIsInstance(response["last_seen_ts"], int)
+ self.assertGreater(response["last_seen_ts"], 0)
+
+
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -941,7 +1139,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -952,7 +1150,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", url, access_token=self.other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -962,12 +1160,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that deactivation for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -976,7 +1174,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_erase_is_not_bool(self) -> None:
@@ -991,18 +1189,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that deactivation for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
channel = self.make_request("POST", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only deactivate local users", channel.json_body["error"])
def test_deactivate_user_erase_true(self) -> None:
@@ -1017,12 +1215,13 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
+ self.assertFalse(channel.json_body["erased"])
# Deactivate and erase user
channel = self.make_request(
@@ -1032,7 +1231,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1041,12 +1240,13 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
+ self.assertTrue(channel.json_body["erased"])
self._is_erased("@user:test", True)
@@ -1066,7 +1266,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_erased("@user:test", True)
def test_deactivate_user_erase_false(self) -> None:
@@ -1081,7 +1281,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1096,7 +1296,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1105,7 +1305,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1135,7 +1335,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1150,7 +1350,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content={"erase": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1159,7 +1359,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1178,11 +1378,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
class UserRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
+ register.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -1220,7 +1420,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -1230,12 +1430,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
channel = self.make_request(
@@ -1244,7 +1444,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
def test_invalid_parameter(self) -> None:
@@ -1259,7 +1459,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"admin": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
# deactivated not bool
@@ -1269,7 +1469,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": "not_bool"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not str
@@ -1279,7 +1479,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": True},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# password not length
@@ -1289,7 +1489,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"password": "x" * 513},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# user_type not valid
@@ -1299,7 +1499,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"user_type": "new type"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# external_ids not valid
@@ -1311,7 +1511,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"external_ids": {"auth_provider": "prov", "wrong_external_id": "id"}
},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1320,7 +1520,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": {"external_id": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
# threepids not valid
@@ -1330,7 +1530,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"medium": "email", "wrong_address": "id"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
channel = self.make_request(
@@ -1339,7 +1539,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": {"address": "value"}},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_get_user(self) -> None:
@@ -1352,7 +1552,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("User", channel.json_body["displayname"])
self._check_fields(channel.json_body)
@@ -1379,7 +1579,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1395,7 +1595,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1434,7 +1634,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1458,7 +1658,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1486,7 +1686,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# before limit of monthly active users is reached
channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
- if channel.code != HTTPStatus.OK:
+ if channel.code != 200:
raise HttpResponseException(
channel.code, channel.result["reason"], channel.result["body"]
)
@@ -1512,7 +1712,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "admin": False},
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1550,7 +1750,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Admin user is not blocked by mau anymore
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1585,7 +1785,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1626,7 +1826,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1666,7 +1866,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
@@ -1684,7 +1884,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "hahaha"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body)
def test_set_displayname(self) -> None:
@@ -1700,7 +1900,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1711,7 +1911,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1733,7 +1933,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1759,7 +1959,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1775,7 +1975,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1791,7 +1991,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"threepids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1818,7 +2018,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1837,7 +2037,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1859,7 +2059,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# other user has this two threepids
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["threepids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1878,7 +2078,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url_first_user,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
@@ -1907,7 +2107,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
# result does not always have the same sort order, therefore it becomes sorted
@@ -1939,7 +2139,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1958,7 +2158,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -1977,7 +2177,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"external_ids": []},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(0, len(channel.json_body["external_ids"]))
@@ -2006,7 +2206,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2032,7 +2232,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2064,7 +2264,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# must fail
- self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body)
+ self.assertEqual(409, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("External id is already in use.", channel.json_body["error"])
@@ -2075,7 +2275,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2093,7 +2293,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertEqual(
@@ -2124,7 +2324,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -2139,7 +2339,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -2158,7 +2358,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -2188,7 +2388,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
@@ -2204,7 +2404,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "Foobar"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
@@ -2228,7 +2428,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Reactivate the user.
channel = self.make_request(
@@ -2237,7 +2437,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2261,7 +2461,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2271,7 +2471,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2295,7 +2495,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -2305,7 +2505,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self._is_erased("@user:test", False)
@@ -2326,7 +2526,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"admin": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2337,7 +2537,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -2354,7 +2554,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": UserTypes.SUPPORT},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2365,7 +2565,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
@@ -2377,7 +2577,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"user_type": None},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2388,7 +2588,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
@@ -2407,7 +2607,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"},
)
- self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -2418,7 +2618,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(0, channel.json_body["deactivated"])
@@ -2431,7 +2631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "deactivated": "false"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Check user is not deactivated
channel = self.make_request(
@@ -2440,13 +2640,111 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_approve_account(self) -> None:
+ """Tests that approving an account correctly sets the approved flag for the user."""
+ url = self.url_prefix % "@bob:test"
+
+ # Create the user using the client-server API since otherwise the user will be
+ # marked as approved automatically.
+ channel = self.make_request(
+ "POST",
+ "register",
+ {
+ "username": "bob",
+ "password": "test",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIs(False, channel.json_body["approved"])
+
+ # Approve user
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content={"approved": True},
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIs(True, channel.json_body["approved"])
+
+ # Check that the user is now approved
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIs(True, channel.json_body["approved"])
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_register_approved(self) -> None:
+ url = self.url_prefix % "@bob:test"
+
+ # Create user
+ channel = self.make_request(
+ "PUT",
+ url,
+ access_token=self.admin_user_tok,
+ content={"password": "abc123", "approved": True},
+ )
+
+ self.assertEqual(201, channel.code, msg=channel.json_body)
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(1, channel.json_body["approved"])
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@bob:test", channel.json_body["name"])
+ self.assertEqual(1, channel.json_body["approved"])
+
def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
@@ -2465,7 +2763,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": True},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
@@ -2486,11 +2784,13 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("avatar_url", content)
self.assertIn("admin", content)
self.assertIn("deactivated", content)
+ self.assertIn("erased", content)
self.assertIn("shadow_banned", content)
self.assertIn("creation_ts", content)
self.assertIn("appservice_id", content)
self.assertIn("consent_server_notice_sent", content)
self.assertIn("consent_version", content)
+ self.assertIn("consent_ts", content)
self.assertIn("external_ids", content)
# This key was removed intentionally. Ensure it is not accidentally re-included.
@@ -2498,7 +2798,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -2520,7 +2819,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2535,7 +2834,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -2549,7 +2848,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2565,7 +2864,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2581,7 +2880,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["joined_rooms"]))
@@ -2602,7 +2901,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
@@ -2649,13 +2948,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
class PushersRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -2678,7 +2976,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -2693,12 +2991,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
channel = self.make_request(
@@ -2707,12 +3005,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
@@ -2722,7 +3020,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_get_pushers(self) -> None:
@@ -2737,7 +3035,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
# Register the pusher
@@ -2749,7 +3047,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.other_user,
access_token=token_id,
kind="http",
@@ -2769,7 +3067,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]:
@@ -2784,7 +3082,6 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
class UserMediaRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -2808,7 +3105,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""Try to list media of an user without authentication."""
channel = self.make_request(method, self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
@@ -2822,12 +3119,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
- """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
+ """Tests that a lookup for a user that does not exist returns a 404"""
url = "/_synapse/admin/v1/users/@unknown_person:test/media"
channel = self.make_request(
method,
@@ -2835,12 +3132,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
- """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST"""
+ """Tests that a lookup for a user that is not a local returns a 400"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
channel = self.make_request(
@@ -2849,7 +3146,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_limit_GET(self) -> None:
@@ -2865,7 +3162,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 5)
self.assertEqual(channel.json_body["next_token"], 5)
@@ -2884,7 +3181,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 5)
self.assertEqual(len(channel.json_body["deleted_media"]), 5)
@@ -2901,7 +3198,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 15)
self.assertNotIn("next_token", channel.json_body)
@@ -2920,7 +3217,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 15)
self.assertEqual(len(channel.json_body["deleted_media"]), 15)
@@ -2937,7 +3234,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(channel.json_body["next_token"], 15)
self.assertEqual(len(channel.json_body["media"]), 10)
@@ -2956,7 +3253,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 10)
self.assertEqual(len(channel.json_body["deleted_media"]), 10)
@@ -2970,7 +3267,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid search order
@@ -2980,7 +3277,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative limit
@@ -2990,7 +3287,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
@@ -3000,7 +3297,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self) -> None:
@@ -3023,7 +3320,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -3036,7 +3333,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), number_media)
self.assertNotIn("next_token", channel.json_body)
@@ -3049,7 +3346,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 19)
self.assertEqual(channel.json_body["next_token"], 19)
@@ -3063,7 +3360,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], number_media)
self.assertEqual(len(channel.json_body["media"]), 1)
self.assertNotIn("next_token", channel.json_body)
@@ -3080,7 +3377,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["media"]))
@@ -3095,7 +3392,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["deleted_media"]))
@@ -3112,7 +3409,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["media"]))
self.assertNotIn("next_token", channel.json_body)
@@ -3138,7 +3435,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_media, channel.json_body["total"])
self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
@@ -3283,7 +3580,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK
+ upload_resource, image_data, user_token, filename, expect_code=200
)
# Extract media ID from the response
@@ -3301,10 +3598,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- HTTPStatus.OK,
+ 200,
channel.code,
msg=(
- f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}"
+ f"Expected to receive a 200 on accessing media: {server_and_media_id}"
),
)
@@ -3350,7 +3647,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], len(expected_media_list))
returned_order = [row["media_id"] for row in channel.json_body["media"]]
@@ -3386,14 +3683,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.url, b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
return channel.json_body["access_token"]
def test_no_auth(self) -> None:
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self) -> None:
@@ -3402,7 +3699,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
def test_send_event(self) -> None:
"""Test that sending event as a user works."""
@@ -3427,7 +3724,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
@@ -3439,21 +3736,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout with the puppet token
channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_user_logout_all(self) -> None:
"""Tests that the target user calling `/logout/all` does *not* expire
@@ -3464,23 +3761,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the real user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should still work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# .. but the real user's tokens shouldn't
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
def test_admin_logout_all(self) -> None:
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3491,23 +3788,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# Test that we can successfully make a request
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Logout all with the admin user token
channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
@unittest.override_config(
{
@@ -3538,7 +3835,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
"com.example.test",
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Login in as the user
@@ -3559,7 +3856,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
room_id,
user=self.other_user,
tok=self.other_user_tok,
- expect_code=HTTPStatus.FORBIDDEN,
+ expect_code=403,
)
# Logging in as the other user and joining a room should work, even
@@ -3576,7 +3873,6 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
],
)
class WhoisRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3594,7 +3890,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self) -> None:
@@ -3609,12 +3905,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user2_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined]
@@ -3623,7 +3919,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only whois a local user", channel.json_body["error"])
def test_get_whois_admin(self) -> None:
@@ -3635,7 +3931,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
@@ -3650,13 +3946,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.url,
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)
class ShadowBanRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3680,7 +3975,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request(method, self.url)
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
@@ -3691,18 +3986,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request(method, self.url, access_token=other_user_token)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["POST", "DELETE"])
def test_user_is_not_local(self, method: str) -> None:
"""
- Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self) -> None:
"""
@@ -3715,7 +4010,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(result.shadow_banned)
channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is shadow-banned (and the cache was cleared).
@@ -3727,7 +4022,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE", self.url, access_token=self.admin_user_tok
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual({}, channel.json_body)
# Ensure the user is no longer shadow-banned (and the cache was cleared).
@@ -3737,7 +4032,6 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
class RateLimitTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3762,7 +4056,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
@@ -3778,13 +4072,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "POST", "DELETE"])
def test_user_does_not_exist(self, method: str) -> None:
"""
- Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
+ Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
@@ -3794,7 +4088,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(
@@ -3806,7 +4100,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
def test_user_is_not_local(self, method: str, error_msg: str) -> None:
"""
- Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
+ Tests that a lookup for a user that is not a local returns a 400
"""
url = (
"/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
@@ -3818,7 +4112,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self) -> None:
@@ -3833,7 +4127,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# messages_per_second is negative
@@ -3844,7 +4138,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"messages_per_second": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is a string
@@ -3855,7 +4149,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": "string"},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# burst_count is negative
@@ -3866,7 +4160,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
content={"burst_count": -1},
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_return_zero_when_null(self) -> None:
@@ -3891,7 +4185,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["messages_per_second"])
self.assertEqual(0, channel.json_body["burst_count"])
@@ -3905,7 +4199,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3916,7 +4210,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 10, "burst_count": 11},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(10, channel.json_body["messages_per_second"])
self.assertEqual(11, channel.json_body["burst_count"])
@@ -3927,7 +4221,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"messages_per_second": 20, "burst_count": 21},
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3937,7 +4231,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(20, channel.json_body["messages_per_second"])
self.assertEqual(21, channel.json_body["burst_count"])
@@ -3947,7 +4241,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
@@ -3957,13 +4251,12 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertNotIn("messages_per_second", channel.json_body)
self.assertNotIn("burst_count", channel.json_body)
class AccountDataTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -3982,7 +4275,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
"""Try to get information of a user without authentication."""
channel = self.make_request("GET", self.url, {})
- self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self) -> None:
@@ -3995,7 +4288,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self) -> None:
@@ -4008,7 +4301,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self) -> None:
@@ -4021,7 +4314,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only look up local users", channel.json_body["error"])
def test_success(self) -> None:
@@ -4042,7 +4335,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
self.url,
access_token=self.admin_user_tok,
)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"a": 1}, channel.json_body["account_data"]["global"]["m.global"]
)
@@ -4050,3 +4343,183 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
{"b": 2},
channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"],
)
+
+
+class UsersByExternalIdTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.get_success(
+ self.store.record_user_external_id(
+ "the-auth-provider", "the-external-id", self.other_user
+ )
+ )
+ self.get_success(
+ self.store.record_user_external_id(
+ "another-auth-provider", "a:complex@external/id", self.other_user
+ )
+ )
+
+ def test_no_auth(self) -> None:
+ """Try to lookup a user without authentication."""
+ url = (
+ "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
+ )
+
+ channel = self.make_request(
+ "GET",
+ url,
+ )
+
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_binding_does_not_exist(self) -> None:
+ """Tests that a lookup for an external ID that does not exist returns a 404"""
+ url = "/_synapse/admin/v1/auth_providers/the-auth-provider/users/unknown-id"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_success(self) -> None:
+ """Tests a successful external ID lookup"""
+ url = (
+ "/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
+ )
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
+
+ def test_success_urlencoded(self) -> None:
+ """Tests a successful external ID lookup with an url-encoded ID"""
+ url = "/_synapse/admin/v1/auth_providers/another-auth-provider/users/a%3Acomplex%40external%2Fid"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
+
+
+class UsersByThreePidTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.get_success(
+ self.store.user_add_threepid(
+ self.other_user, "email", "user@email.com", 1, 1
+ )
+ )
+ self.get_success(
+ self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1)
+ )
+
+ def test_no_auth(self) -> None:
+ """Try to look up a user without authentication."""
+ url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ )
+
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_medium_does_not_exist(self) -> None:
+ """Tests that both a lookup for a medium that does not exist and a user that
+ doesn't exist with that third party ID returns a 404"""
+ # test for unknown medium
+ url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ # test for unknown user with a known medium
+ url = "/_synapse/admin/v1/threepid/email/users/unknown"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_success(self) -> None:
+ """Tests a successful medium + address lookup"""
+ # test for email medium with encoded value of user@email.com
+ url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
+
+ # test for msidn medium with encoded value of +1-12345678
+ url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ {"user_id": self.other_user},
+ channel.json_body,
+ )
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index b21f6d4689..30f12f1bff 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from http import HTTPStatus
-
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -40,7 +37,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
if username == "allowed":
return True
raise SynapseError(
- HTTPStatus.BAD_REQUEST,
+ 400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@@ -50,27 +47,23 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
def test_username_available(self) -> None:
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "allowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok)
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["available"])
def test_username_unavailable(self) -> None:
"""
- The endpoint should return a HTTPStatus.OK response if the username does not exist
+ The endpoint should return a 200 response if the username does not exist
"""
url = "%s?username=%s" % (self.url, "disallowed")
channel = self.make_request("GET", url, access_token=self.admin_user_tok)
- self.assertEqual(
- HTTPStatus.BAD_REQUEST,
- channel.code,
- msg=channel.json_body,
- )
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE")
self.assertEqual(channel.json_body["error"], "User ID already taken.")
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 7ae926dc9c..c1a7fb2f8a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -488,7 +488,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -641,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
@@ -1001,7 +1001,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertEqual(expected_error, channel.json_body["error"])
+ self.assertIn(expected_error, channel.json_body["error"])
def _validate_token(self, link: str) -> None:
# Remove the host
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 05355c7fb6..208ec44829 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -20,7 +21,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
-from synapse.api.constants import LoginType
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes, SynapseError
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -31,8 +33,8 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.utils import TEST_OIDC_CONFIG
-from tests.server import FakeChannel
+from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
+from tests.server import FakeChannel, make_request
from tests.unittest import override_config, skip_unless
@@ -464,9 +466,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
* checking that the original operation succeeds
"""
+ fake_oidc_server = self.helper.fake_oidc_server()
+
# log the user in
remote_user_id = UserID.from_string(self.user).localpart
- login_resp = self.helper.login_via_oidc(remote_user_id)
+ login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
self.assertEqual(login_resp["user_id"], self.user)
# initiate a UI Auth process by attempting to delete the device
@@ -480,8 +484,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
# run the UIA-via-SSO flow
session_id = channel.json_body["session"]
- channel = self.helper.auth_via_oidc(
- {"sub": remote_user_id}, ui_auth_session_id=session_id
+ channel, _ = self.helper.auth_via_oidc(
+ fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
)
# that should serve a confirmation page
@@ -498,7 +502,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self) -> None:
- login_resp = self.helper.login_via_oidc("username")
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
@@ -521,7 +526,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
- login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_resp, _ = self.helper.login_via_oidc(
+ fake_oidc_server, UserID.from_string(self.user).localpart
+ )
self.assertEqual(login_resp["user_id"], self.user)
channel = self.delete_device(
@@ -538,8 +546,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+
# log the user in
- login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ login_resp, _ = self.helper.login_via_oidc(
+ fake_oidc_server, UserID.from_string(self.user).localpart
+ )
self.assertEqual(login_resp["user_id"], self.user)
# start a UI Auth flow by attempting to delete a device
@@ -552,8 +565,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
session_id = channel.json_body["session"]
# do the OIDC auth, but auth as the wrong user
- channel = self.helper.auth_via_oidc(
- {"sub": "wrong_user"}, ui_auth_session_id=session_id
+ channel, _ = self.helper.auth_via_oidc(
+ fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
)
# that should return a failure message
@@ -567,6 +580,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
body={"auth": {"session": session_id}},
)
+ @skip_unless(HAS_OIDC, "requires OIDC")
+ @override_config(
+ {
+ "oidc_config": TEST_OIDC_CONFIG,
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ },
+ }
+ )
+ def test_sso_not_approved(self) -> None:
+ """Tests that if we register a user via SSO while requiring approval for new
+ accounts, we still raise the correct error before logging the user in.
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_resp, _ = self.helper.login_via_oidc(
+ fake_oidc_server, "username", expected_status=403
+ )
+
+ self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"]
+ )
+
+ # Check that we didn't register a device for the user during the login attempt.
+ devices = self.get_success(
+ self.hs.get_datastores().main.get_devices_by_user("@username:test")
+ )
+
+ self.assertEqual(len(devices), 0)
+
class RefreshAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -589,23 +635,10 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
"""
return self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": refresh_token},
)
- def is_access_token_valid(self, access_token: str) -> bool:
- """
- Checks whether an access token is valid, returning whether it is or not.
- """
- code = self.make_request(
- "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
- ).code
-
- # Either 200 or 401 is what we get back; anything else is a bug.
- assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
- return code == HTTPStatus.OK
-
def test_login_issue_refresh_token(self) -> None:
"""
A login response should include a refresh_token only if asked.
@@ -691,7 +724,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
@@ -732,7 +765,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
@@ -802,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.reactor.advance(59.0)
# Both tokens should still be valid.
- self.assertTrue(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 61 s (just past 1 minute, the time of expiry)
self.reactor.advance(2.0)
# Only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 599 s (just shy of 10 minutes, the time of expiry)
self.reactor.advance(599.0 - 61.0)
# It's still the case that only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 601 s (just past 10 minutes, the time of expiry)
self.reactor.advance(2.0)
# Now neither token is valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(
+ nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
@override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
@@ -961,7 +1002,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This first refresh should work properly
first_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -971,7 +1012,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This one as well, since the token in the first one was never used
second_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -981,7 +1022,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# This one should not, since the token from the first refresh is not valid anymore
third_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": first_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -1015,7 +1056,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
fourth_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": login_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -1027,7 +1068,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# But refreshing from the last valid refresh token still works
fifth_refresh_response = self.make_request(
"POST",
- "/_matrix/client/v1/refresh",
+ "/_matrix/client/v3/refresh",
{"refresh_token": second_refresh_response.json_body["refresh_token"]},
)
self.assertEqual(
@@ -1120,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# and no refresh token
self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0)
+
+
+def oidc_config(
+ id: str, with_localpart_template: bool, **kwargs: Any
+) -> Dict[str, Any]:
+ """Sample OIDC provider config used in backchannel logout tests.
+
+ Args:
+ id: IDP ID for this provider
+ with_localpart_template: Set to `true` to have a default localpart_template in
+ the `user_mapping_provider` config and skip the user mapping session
+ **kwargs: rest of the config
+
+ Returns:
+ A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
+ the HS config
+ """
+ config: Dict[str, Any] = {
+ "idp_id": id,
+ "idp_name": id,
+ "issuer": TEST_OIDC_ISSUER,
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["openid"],
+ }
+
+ if with_localpart_template:
+ config["user_mapping_provider"] = {
+ "config": {"localpart_template": "{{ user.sub }}"}
+ }
+ else:
+ config["user_mapping_provider"] = {"config": {}}
+
+ config.update(kwargs)
+
+ return config
+
+
+@skip_unless(HAS_OIDC, "Requires OIDC")
+class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
+ servlets = [
+ account.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+
+ # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+ # False, so synapse will see the requested uri as http://..., so using http in
+ # the public_baseurl stops Synapse trying to redirect to https.
+ config["public_baseurl"] = "http://synapse.test"
+
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resource_dict = super().create_resource_dict()
+ resource_dict.update(build_synapse_client_resource_tree(self.hs))
+ return resource_dict
+
+ def submit_logout_token(self, logout_token: str) -> FakeChannel:
+ return self.make_request(
+ "POST",
+ "/_synapse/client/oidc/backchannel_logout",
+ content=f"logout_token={logout_token}",
+ content_is_form=True,
+ )
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_simple_logout(self) -> None:
+ """
+ Receiving a logout token should logout the user
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, first_grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ first_access_token: str = login_resp["access_token"]
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+ login_resp, second_grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ second_access_token: str = login_resp["access_token"]
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ self.assertNotEqual(first_grant.sid, second_grant.sid)
+ self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+ # Logging out of the first session
+ logout_token = fake_oidc_server.generate_logout_token(first_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out of the second session
+ logout_token = fake_oidc_server.generate_logout_token(second_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_logout_during_login(self) -> None:
+ """
+ It should revoke login tokens when receiving a logout token
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ # Get an authentication, and logout before submitting the logout token
+ client_redirect_url = "https://x"
+ userinfo = {"sub": user}
+ channel, grant = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=True,
+ )
+
+ # expect a confirmation page
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.text_body,
+ )
+ assert m, channel.text_body
+ login_token = m.group(1)
+
+ # Submit a logout
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ # Now try to exchange the login token
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ # It should have failed
+ self.assertEqual(channel.code, 403)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=False,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_logout_during_mapping(self) -> None:
+ """
+ It should stop ongoing user mapping session when receiving a logout token
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ # Get an authentication, and logout before submitting the logout token
+ client_redirect_url = "https://x"
+ userinfo = {"sub": user}
+ channel, grant = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=True,
+ )
+
+ # Expect a user mapping page
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+
+ # We should have a user_mapping_session cookie
+ cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+ assert cookie_headers
+ cookies: Dict[str, str] = {}
+ for h in cookie_headers:
+ key, value = h.split(";")[0].split("=", maxsplit=1)
+ cookies[key] = value
+
+ user_mapping_session_id = cookies["username_mapping_session"]
+
+ # Getting that session should not raise
+ session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+ self.assertIsNotNone(session)
+
+ # Submit a logout
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ # Now it should raise
+ with self.assertRaises(SynapseError):
+ self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=False,
+ )
+ ]
+ }
+ )
+ def test_disabled(self) -> None:
+ """
+ Receiving a logout token should do nothing if it is disabled in the config
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ access_token: str = login_resp["access_token"]
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out shouldn't work
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 400)
+
+ # And the token should still be valid
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_no_sid(self) -> None:
+ """
+ Receiving a logout token without `sid` during the login should do nothing
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=False
+ )
+ access_token: str = login_resp["access_token"]
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out shouldn't work
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 400)
+
+ # And the token should still be valid
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ "first",
+ issuer="https://first-issuer.com/",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ ),
+ oidc_config(
+ "second",
+ issuer="https://second-issuer.com/",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ ),
+ ]
+ }
+ )
+ def test_multiple_providers(self) -> None:
+ """
+ It should be able to distinguish login tokens from two different IdPs
+ """
+ first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
+ second_server = self.helper.fake_oidc_server(
+ issuer="https://second-issuer.com/"
+ )
+ user = "john"
+
+ login_resp, first_grant = self.helper.login_via_oidc(
+ first_server, user, with_sid=True, idp_id="oidc-first"
+ )
+ first_access_token: str = login_resp["access_token"]
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+ login_resp, second_grant = self.helper.login_via_oidc(
+ second_server, user, with_sid=True, idp_id="oidc-second"
+ )
+ second_access_token: str = login_resp["access_token"]
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # `sid` in the fake providers are generated by a counter, so the first grant of
+ # each provider should give the same SID
+ self.assertEqual(first_grant.sid, second_grant.sid)
+ self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+ # Logging out of the first session
+ logout_token = first_server.generate_logout_token(first_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out of the second session
+ logout_token = second_server.generate_logout_token(second_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index aa98222434..d80eea17d3 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase):
self.reactor.advance(43200)
self.get_success(self.handler.get_device(user_id, "abc"))
self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError)
+
+
+class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def test_PUT(self) -> None:
+ """Sanity-check that we can PUT a dehydrated device.
+
+ Detects https://github.com/matrix-org/synapse/issues/14334.
+ """
+ alice = self.register_user("alice", "correcthorse")
+ token = self.login(alice, "correcthorse")
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device",
+ {
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device",
+ }
+ },
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ device_id = channel.json_body.get("device_id")
+ self.assertIsInstance(device_id, str)
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 823e8ab8c4..afc8d641be 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
@@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self) -> None:
@@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.hs.is_mine = _is_mine
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self) -> None:
@@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self) -> None:
@@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"404")
+ self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
@@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
# No ID also returns an invalid_id error
def test_get_filter_no_id(self) -> None:
@@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index dc17c9d113..b0c8215744 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -25,7 +25,6 @@ from tests import unittest
class IdentityTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -33,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
@@ -54,6 +52,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
+ "id_access_token": tok,
}
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index bbc8e74243..741fecea77 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -19,6 +19,7 @@ from synapse.rest import admin
from synapse.rest.client import keys, login
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
Codes.BAD_JSON,
channel.result,
)
+
+ def test_key_query_cancellation(self) -> None:
+ """
+ Tests that /keys/query is cancellable and does not swallow the
+ CancelledError.
+ """
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+
+ bob = self.register_user("bob", "uncle")
+
+ channel = make_request_with_cancellation_test(
+ "test_key_query_cancellation",
+ self.reactor,
+ self.site,
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ # Empty list means we request keys for all bob's devices
+ bob: [],
+ },
+ },
+ token=alice_token,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertIn(bob, channel.json_body["device_keys"])
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a2958f6959..ff5baa9f0a 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,7 +13,6 @@
# limitations under the License.
import time
import urllib.parse
-from http import HTTPStatus
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -24,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
+from synapse.api.constants import ApprovalNoticeMedium, LoginType
+from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
@@ -35,7 +36,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
-from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
@@ -95,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
logout.register_servlets,
devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
+ register.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -134,10 +136,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -152,7 +154,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -179,10 +181,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -197,7 +199,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -224,10 +226,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -242,7 +244,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None:
@@ -250,7 +252,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
# log in as normal
@@ -261,20 +263,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -288,7 +290,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -296,7 +298,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False)
@@ -307,7 +309,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
set(channel.json_body.keys()),
@@ -330,7 +332,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token,
content={"auth": auth},
)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,20 +343,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@@ -367,20 +369,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese")
@@ -407,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_require_approval(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "register",
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
+ params = {
+ "type": LoginType.PASSWORD,
+ "identifier": {"type": "m.id.user", "user": "kermit"},
+ "password": "monkey",
+ }
+ channel = self.make_request("POST", LOGIN_URL, params)
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
@@ -466,7 +506,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [
"m.login.cas",
@@ -494,14 +534,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None)
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
uri = location_headers[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8")
@@ -530,7 +570,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas",
shorthand=False,
)
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
cas_uri = location_headers[0]
@@ -555,7 +595,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
saml_uri = location_headers[0]
@@ -572,21 +612,24 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_login_via_oidc(self) -> None:
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
- # pick the default OIDC provider
- channel = self.make_request(
- "GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=oidc",
- )
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ with fake_oidc_server.patch_homeserver(hs=self.hs):
+ # pick the default OIDC provider
+ channel = self.make_request(
+ "GET",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=oidc",
+ )
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
- self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+ self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
@@ -603,10 +646,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
TEST_CLIENT_REDIRECT_URL,
)
- channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+ channel, _ = self.helper.complete_oidc_auth(
+ fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
+ )
# that should serve a confirmation page
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +679,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
- self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+ self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,25 +688,28 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx")
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it"""
- channel = self._make_sso_redirect_request("oidc")
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ with fake_oidc_server.patch_homeserver(hs=self.hs):
+ channel = self._make_sso_redirect_request("oidc")
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
- self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+ self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
"""Send a request to /_matrix/client/r0/login/sso/redirect
@@ -765,7 +813,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
@@ -878,17 +926,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -897,7 +945,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -907,7 +955,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_not_before(self) -> None:
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -916,7 +964,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@@ -925,12 +973,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -939,7 +987,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -949,7 +997,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@@ -957,12 +1005,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -971,7 +1019,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -981,7 +1029,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -991,20 +1039,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -1086,12 +1134,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -1152,7 +1200,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login"""
@@ -1166,7 +1214,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token"""
@@ -1180,7 +1228,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token"""
@@ -1194,7 +1242,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice
@@ -1208,7 +1256,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
@skip_unless(HAS_OIDC, "requires OIDC")
@@ -1240,13 +1288,17 @@ class UsernamePickerTestCase(HomeserverTestCase):
def test_username_picker(self) -> None:
"""Test the happy path of a username picker flow."""
+ fake_oidc_server = self.helper.fake_oidc_server()
+
# do the start of the login flow
- channel = self.helper.auth_via_oidc(
- {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+ channel, _ = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ {"sub": "tester", "displayname": "Jonny"},
+ TEST_CLIENT_REDIRECT_URL,
)
# that should redirect to the username picker
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
picker_url = location_headers[0]
@@ -1290,7 +1342,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))),
],
)
- self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+ self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1300,7 +1352,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
)
- self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+ self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1325,5 +1377,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
- self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+ self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
new file mode 100644
index 0000000000..c2e1e08811
--- /dev/null
+++ b/tests/rest/client/test_login_token_request.py
@@ -0,0 +1,134 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, login_token_request
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+
+
+class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ login_token_request.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.registration.enable_registration = True
+ self.hs.config.registration.registrations_require_3pid = []
+ self.hs.config.registration.auto_join_rooms = []
+ self.hs.config.captcha.enable_registration_captcha = False
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = "user123"
+ self.password = "password"
+
+ def test_disabled(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 400)
+
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 400)
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_require_auth(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 401)
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_uia_on(self) -> None:
+ user_id = self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 401)
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ session = channel.json_body["session"]
+
+ uia = {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.password,
+ "session": session,
+ },
+ }
+
+ channel = self.make_request("POST", endpoint, uia, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 300)
+
+ login_token = channel.json_body["login_token"]
+
+ channel = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ @override_config(
+ {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+ )
+ def test_uia_off(self) -> None:
+ user_id = self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 300)
+
+ login_token = channel.json_body["login_token"]
+
+ channel = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3882_enabled": True,
+ "msc3882_ui_auth": False,
+ "msc3882_token_timeout": "15s",
+ }
+ }
+ )
+ def test_expires_in(self) -> None:
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", endpoint, {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 15)
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
new file mode 100644
index 0000000000..0b8fcb0c47
--- /dev/null
+++ b/tests/rest/client/test_models.py
@@ -0,0 +1,76 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest as stdlib_unittest
+
+from pydantic import BaseModel, ValidationError
+from typing_extensions import Literal
+
+from synapse.rest.client.models import EmailRequestTokenBody
+
+
+class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
+ class Model(BaseModel):
+ medium: Literal["email", "msisdn"]
+
+ def test_accepts_valid_medium_string(self) -> None:
+ """Sanity check that Pydantic behaves sensibly with an enum-of-str
+
+ This is arguably more of a test of a class that inherits from str and Enum
+ simultaneously.
+ """
+ model = self.Model.parse_obj({"medium": "email"})
+ self.assertEqual(model.medium, "email")
+
+ def test_rejects_invalid_medium_value(self) -> None:
+ with self.assertRaises(ValidationError):
+ self.Model.parse_obj({"medium": "interpretive_dance"})
+
+ def test_rejects_invalid_medium_type(self) -> None:
+ with self.assertRaises(ValidationError):
+ self.Model.parse_obj({"medium": 123})
+
+
+class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
+ base_request = {
+ "client_secret": "hunter2",
+ "email": "alice@wonderland.com",
+ "send_attempt": 1,
+ }
+
+ def test_token_required_if_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ }
+ )
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": None,
+ }
+ )
+
+ def test_token_typechecked_when_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": 1337,
+ }
+ )
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 7401b5e0c0..5dfe44defb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class RedactionsTestCase(HomeserverTestCase):
@@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase):
)
def _redact_event(
- self, access_token: str, room_id: str, event_id: str, expect_code: int = 200
+ self,
+ access_token: str,
+ room_id: str,
+ event_id: str,
+ expect_code: int = 200,
+ with_relations: Optional[List[str]] = None,
) -> JsonDict:
"""Helper function to send a redaction event.
@@ -75,13 +81,19 @@ class RedactionsTestCase(HomeserverTestCase):
"""
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
- channel = self.make_request("POST", path, content={}, access_token=access_token)
- self.assertEqual(int(channel.result["code"]), expect_code)
+ request_content = {}
+ if with_relations:
+ request_content["org.matrix.msc3912.with_relations"] = with_relations
+
+ channel = self.make_request(
+ "POST", path, request_content, access_token=access_token
+ )
+ self.assertEqual(channel.code, expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
@@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase):
# These should all succeed, even though this would be denied by
# the standard message ratelimiter
self._redact_event(self.mod_access_token, self.room_id, msg_id)
+
+ @override_config({"experimental_features": {"msc3912_enabled": True}})
+ def test_redact_relations(self) -> None:
+ """Tests that we can redact the relations of an event at the same time as the
+ event itself.
+ """
+ # Send a root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "hello"},
+ tok=self.mod_access_token,
+ )
+ root_event_id = res["event_id"]
+
+ # Send an edit to this root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": " * hello world",
+ "m.new_content": {
+ "body": "hello world",
+ "msgtype": "m.text",
+ },
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.REPLACE,
+ },
+ "msgtype": "m.text",
+ },
+ tok=self.mod_access_token,
+ )
+ edit_event_id = res["event_id"]
+
+ # Also send a threaded message whose root is the same as the edit's.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message 1",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.mod_access_token,
+ )
+ threaded_event_id = res["event_id"]
+
+ # Also send a reaction, again with the same root.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Reaction,
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": root_event_id,
+ "key": "👍",
+ }
+ },
+ tok=self.mod_access_token,
+ )
+ reaction_event_id = res["event_id"]
+
+ # Redact the root event, specifying that we also want to delete events that
+ # relate to it with m.replace.
+ self._redact_event(
+ self.mod_access_token,
+ self.room_id,
+ root_event_id,
+ with_relations=[
+ RelationTypes.REPLACE,
+ RelationTypes.THREAD,
+ ],
+ )
+
+ # Check that the root event got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, root_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the edit got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, edit_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the threaded message got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, threaded_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the reaction did not get redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, reaction_event_id, self.mod_access_token
+ )
+ self.assertNotIn("redacted_because", event_dict, event_dict)
+
+ @override_config({"experimental_features": {"msc3912_enabled": True}})
+ def test_redact_relations_no_perms(self) -> None:
+ """Tests that, when redacting a message along with its relations, if not all
+ the related messages can be redacted because of insufficient permissions, the
+ server still redacts all the ones that can be.
+ """
+ # Send a root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "root",
+ },
+ tok=self.other_access_token,
+ )
+ root_event_id = res["event_id"]
+
+ # Send a first threaded message, this one from the moderator. We do this for the
+ # first message with the m.thread relation (and not the last one) to ensure
+ # that, when the server fails to redact it, it doesn't stop there, and it
+ # instead goes on to redact the other one.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message 1",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.mod_access_token,
+ )
+ first_threaded_event_id = res["event_id"]
+
+ # Send a second threaded message, this time from the user who'll perform the
+ # redaction.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message 2",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.other_access_token,
+ )
+ second_threaded_event_id = res["event_id"]
+
+ # Redact the thread's root, and request that all threaded messages are also
+ # redacted. Send that request from the non-mod user, so that the first threaded
+ # event cannot be redacted.
+ self._redact_event(
+ self.other_access_token,
+ self.room_id,
+ root_event_id,
+ with_relations=[RelationTypes.THREAD],
+ )
+
+ # Check that the thread root got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, root_event_id, self.other_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the last message in the thread got redacted, despite failing to
+ # redact the one before it.
+ event_dict = self.helper.get_event(
+ self.room_id, second_threaded_event_id, self.other_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the message that was sent into the tread by the mod user is not
+ # redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, first_threaded_event_id, self.other_access_token
+ )
+ self.assertIn("body", event_dict["content"], event_dict)
+ self.assertEqual("message 1", event_dict["content"]["body"])
+
+ @override_config({"experimental_features": {"msc3912_enabled": True}})
+ def test_redact_relations_txn_id_reuse(self) -> None:
+ """Tests that redacting a message using a transaction ID, then reusing the same
+ transaction ID but providing an additional list of relations to redact, is
+ effectively a no-op.
+ """
+ # Send a root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "root",
+ },
+ tok=self.mod_access_token,
+ )
+ root_event_id = res["event_id"]
+
+ # Send a first threaded message.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "I'm in a thread!",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.mod_access_token,
+ )
+ threaded_event_id = res["event_id"]
+
+ # Send a first redaction request which redacts only the root event.
+ channel = self.make_request(
+ method="PUT",
+ path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
+ content={},
+ access_token=self.mod_access_token,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Send a second redaction request which redacts the root event as well as
+ # threaded messages.
+ channel = self.make_request(
+ method="PUT",
+ path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
+ content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]},
+ access_token=self.mod_access_token,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Check that the root event got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, root_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict)
+
+ # Check that the threaded message didn't get redacted (since that wasn't part of
+ # the original redaction).
+ event_dict = self.helper.get_event(
+ self.room_id, threaded_event_id, self.mod_access_token
+ )
+ self.assertIn("body", event_dict["content"], event_dict)
+ self.assertEqual("I'm in a thread!", event_dict["content"]["body"])
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index f8e64ce6ac..11cf3939d8 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -22,7 +22,11 @@ import pkg_resources
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
+from synapse.api.constants import (
+ APP_SERVICE_REGISTRATION_TYPE,
+ ApprovalNoticeMedium,
+ LoginType,
+)
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
@@ -70,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
@@ -91,7 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists
@@ -100,20 +104,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
def test_POST_bad_password(self) -> None:
request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None:
request_data = {"username": 777, "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self) -> None:
@@ -132,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False})
@@ -142,7 +146,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -153,7 +157,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self) -> None:
@@ -161,7 +165,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
@@ -171,16 +175,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", url, b"{}")
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
@@ -194,16 +198,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self) -> None:
@@ -231,7 +235,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Request without auth to get flows and session
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow.
@@ -248,7 +252,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -263,7 +267,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0
@@ -293,21 +297,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -361,7 +365,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session2,
}
channel = self.make_request(b"POST", self.url, params2)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -381,7 +385,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -415,7 +419,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -570,7 +574,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we only expect the dummy flow
@@ -586,14 +590,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"require_at_registration": True,
},
"account_threepid_delegates": {
- "email": "https://id_server",
"msisdn": "https://id_server",
},
+ "email": {"notif_from": "Synapse <synapse@example.com>"},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
self.assertCountEqual(
@@ -625,7 +629,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid
@@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3866": {
+ "enabled": True,
+ "require_approval_for_new_accounts": True,
+ }
+ }
+ }
+ )
+ def test_require_approval(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "register",
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(403, channel.code, channel.result)
+ self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
+ self.assertEqual(
+ ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
+ )
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
@@ -797,13 +827,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -823,12 +853,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v1/account_validity/validity"
request_data = {"user_id": user_id}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey")
@@ -844,12 +874,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -868,18 +898,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
"enable_renewal_emails": False,
}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Try to log the user out
channel = self.make_request(b"POST", "/logout", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -954,7 +984,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -972,7 +1002,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Move 1 day forward. Try to renew with the same token again.
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -992,14 +1022,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"404", channel.result)
+ self.assertEqual(channel.code, 404, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1023,7 +1053,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1096,7 +1126,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1176,7 +1206,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self) -> None:
@@ -1185,7 +1215,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], False)
@override_config(
@@ -1201,10 +1231,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@@ -1212,4 +1242,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index d589f07314..b86f341ff5 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -654,6 +654,14 @@ class RelationsTestCase(BaseRelationsTestCase):
)
# We also expect to get the original event (the id of which is self.parent_id)
+ # when requesting the unstable endpoint.
+ self.assertNotIn("original_event", channel.json_body)
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(
channel.json_body["original_event"]["event_id"], self.parent_id
)
@@ -728,7 +736,6 @@ class RelationsTestCase(BaseRelationsTestCase):
class RelationPaginationTestCase(BaseRelationsTestCase):
- @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_basic_paginate_relations(self) -> None:
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -756,11 +763,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel.json_body["chunk"][0],
)
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
@@ -771,7 +773,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations"
- f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+ f"/{self.parent_id}?limit=1&dir=f",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -809,7 +811,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
- f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -827,6 +829,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
+ # Test forward pagination.
+ prev_token = ""
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?dir=f&limit=3{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(found_event_ids, expected_event_ids)
+
def test_pagination_from_sync_and_messages(self) -> None:
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
@@ -999,7 +1027,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
@@ -1035,7 +1063,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
def test_thread(self) -> None:
"""
@@ -1080,21 +1108,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
)
# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
)
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@@ -1142,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
def test_nested_thread(self) -> None:
"""
@@ -1495,6 +1523,26 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
+ def _get_threads(self) -> List[Tuple[str, str]]:
+ """Request the threads in the room and returns a list of thread ID and latest event ID."""
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ threads = channel.json_body["chunk"]
+ return [
+ (
+ t["event_id"],
+ t["unsigned"]["m.relations"][RelationTypes.THREAD]["latest_event"][
+ "event_id"
+ ],
+ )
+ for t in threads
+ ]
+
def test_redact_relation_annotation(self) -> None:
"""
Test that annotations of an event are properly handled after the
@@ -1539,58 +1587,82 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
The redacted event should not be included in bundled aggregations or
the response to relations.
"""
- channel = self._send_relation(
- RelationTypes.THREAD,
- EventTypes.Message,
- content={"body": "reply 1", "msgtype": "m.text"},
- )
- unredacted_event_id = channel.json_body["event_id"]
+ # Create a thread with a few events in it.
+ thread_replies = []
+ for i in range(3):
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": f"reply {i}", "msgtype": "m.text"},
+ )
+ thread_replies.append(channel.json_body["event_id"])
- # Note that the *last* event in the thread is redacted, as that gets
- # included in the bundled aggregation.
- channel = self._send_relation(
- RelationTypes.THREAD,
- EventTypes.Message,
- content={"body": "reply 2", "msgtype": "m.text"},
+ ##################################################
+ # Check the test data is configured as expected. #
+ ##################################################
+ self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
+ relations = self._get_bundled_aggregations()
+ self.assertDictContainsSubset(
+ {"count": 3, "current_user_participated": True},
+ relations[RelationTypes.THREAD],
+ )
+ # The latest event is the last sent event.
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ thread_replies[-1],
)
- to_redact_event_id = channel.json_body["event_id"]
- # Both relations exist.
- event_ids = self._get_related_events()
+ # There should be one thread, the latest event is the event that will be redacted.
+ self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])])
+
+ ##########################
+ # Redact the last event. #
+ ##########################
+ self._redact(thread_replies.pop())
+
+ # The thread should still exist, but the latest event should be updated.
+ self.assertEquals(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertDictContainsSubset(
- {
- "count": 2,
- "current_user_participated": True,
- },
+ {"count": 2, "current_user_participated": True},
relations[RelationTypes.THREAD],
)
- # And the latest event returned is the event that will be redacted.
+ # And the latest event is the last unredacted event.
self.assertEqual(
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
- to_redact_event_id,
+ thread_replies[-1],
)
+ self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])])
- # Redact one of the reactions.
- self._redact(to_redact_event_id)
+ ###########################################
+ # Redact the *first* event in the thread. #
+ ###########################################
+ self._redact(thread_replies.pop(0))
- # The unredacted relation should still exist.
- event_ids = self._get_related_events()
+ # Nothing should have changed (except the thread count).
+ self.assertEquals(self._get_related_events(), thread_replies)
relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [unredacted_event_id])
self.assertDictContainsSubset(
- {
- "count": 1,
- "current_user_participated": True,
- },
+ {"count": 1, "current_user_participated": True},
relations[RelationTypes.THREAD],
)
- # And the latest event is now the unredacted event.
+ # And the latest event is the last unredacted event.
self.assertEqual(
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
- unredacted_event_id,
+ thread_replies[-1],
)
+ self.assertEqual(self._get_threads(), [(self.parent_id, thread_replies[-1])])
+
+ ####################################
+ # Redact the last remaining event. #
+ ####################################
+ self._redact(thread_replies.pop(0))
+ self.assertEquals(thread_replies, [])
+
+ # The event should no longer be considered a thread.
+ self.assertEquals(self._get_related_events(), [])
+ self.assertEquals(self._get_bundled_aggregations(), {})
+ self.assertEqual(self._get_threads(), [])
def test_redact_parent_edit(self) -> None:
"""Test that edits of an event are redacted when the original event
@@ -1649,7 +1721,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_redact_parent_thread(self) -> None:
"""
Test that thread replies are still available when the root event is redacted.
@@ -1679,3 +1750,165 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
related_event_id,
)
+
+
+class ThreadsTestCase(BaseRelationsTestCase):
+ def _get_threads(self, body: JsonDict) -> List[Tuple[str, str]]:
+ return [
+ (
+ ev["event_id"],
+ ev["unsigned"]["m.relations"]["m.thread"]["latest_event"]["event_id"],
+ )
+ for ev in body["chunk"]
+ ]
+
+ def test_threads(self) -> None:
+ """Create threads and ensure the ordering is due to their latest event."""
+ # Create 2 threads.
+ thread_1 = self.parent_id
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+ thread_2 = res["event_id"]
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", parent_id=thread_2
+ )
+ reply_2 = channel.json_body["event_id"]
+
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ threads = self._get_threads(channel.json_body)
+ self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
+
+ # Update the first thread, the ordering should swap.
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ reply_3 = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ # Tuple of (thread ID, latest event ID) for each thread.
+ threads = self._get_threads(channel.json_body)
+ self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
+
+ def test_pagination(self) -> None:
+ """Create threads and paginate through them."""
+ # Create 2 threads.
+ thread_1 = self.parent_id
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+ thread_2 = res["event_id"]
+
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_2])
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ next_batch = channel.json_body.get("next_batch")
+ self.assertIsInstance(next_batch, str, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_1], channel.json_body)
+
+ self.assertNotIn("next_batch", channel.json_body, channel.json_body)
+
+ def test_include(self) -> None:
+ """Filtering threads to all or participated in should work."""
+ # Thread 1 has the user as the root event.
+ thread_1 = self.parent_id
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+
+ # Thread 2 has the user replying.
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+ thread_2 = res["event_id"]
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Thread 3 has the user not participating in.
+ res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
+ thread_3 = res["event_id"]
+ self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.test",
+ access_token=self.user2_token,
+ parent_id=thread_3,
+ )
+
+ # All threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(
+ thread_roots, [thread_3, thread_2, thread_1], channel.json_body
+ )
+
+ # Only participated threads.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
+
+ def test_ignored_user(self) -> None:
+ """Events from ignored users should be ignored."""
+ # Thread 1 has a reply from an ignored user.
+ thread_1 = self.parent_id
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+
+ # Thread 2 is created by an ignored user.
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+ thread_2 = res["event_id"]
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Ignore user2.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user2_id: {}}},
+ )
+ )
+
+ # Only thread 1 is returned.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_1], channel.json_body)
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
new file mode 100644
index 0000000000..ad00a476e1
--- /dev/null
+++ b/tests/rest/client/test_rendezvous.py
@@ -0,0 +1,45 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest.client import rendezvous
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
+
+
+class RendezvousServletTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ rendezvous.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+ return self.hs
+
+ def test_disabled(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 400)
+
+ @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
+ def test_redirect(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 307)
+ self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"])
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index ad0d0209f7..7cb1017a4a 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -77,6 +77,4 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok
)
- self.assertEqual(
- response_status, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEqual(response_status, channel.code, msg=channel.result["body"])
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from synapse.visibility import filter_events_for_client
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key=""
+ create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index aa2f578441..e919e089cb 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
import json
from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call
+from unittest.mock import Mock, call, patch
from urllib import parse as urlparse
from parameterized import param, parameterized
@@ -35,13 +35,15 @@ from synapse.api.constants import (
EventTypes,
Membership,
PublicRoomsFilterFields,
- RelationTypes,
RoomTypes,
)
from synapse.api.errors import Codes, HttpResponseException
+from synapse.appservice import ApplicationService
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
-from synapse.rest.client import account, directory, login, profile, room, sync
+from synapse.rest.client import account, directory, login, profile, register, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock
@@ -49,7 +51,10 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
+from tests.storage.test_stream import PaginationTestCase
from tests.test_utils import make_awaitable
+from tests.test_utils.event_injection import create_event
+from tests.unittest import override_config
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -710,7 +715,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(44, channel.resource_usage.db_txn_count)
+ self.assertEqual(33, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -723,7 +728,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(50, channel.resource_usage.db_txn_count)
+ self.assertEqual(36, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -867,6 +872,41 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
+ def _create_basic_room(self) -> Tuple[int, object]:
+ """
+ Tries to create a basic room and returns the response code.
+ """
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ return channel.code, channel.json_body
+
+ @override_config(
+ {
+ "rc_message": {"per_second": 0.2, "burst_count": 10},
+ }
+ )
+ def test_room_creation_ratelimiting(self) -> None:
+ """
+ Regression test for #14312, where ratelimiting was made too strict.
+ Clients should be able to create 10 rooms in a row
+ without hitting rate limits, using default rate limit config.
+ (We override rate limiting config back to its default value.)
+
+ To ensure we don't make ratelimiting too generous accidentally,
+ also check that we can't create an 11th room.
+ """
+
+ for _ in range(10):
+ code, json_body = self._create_basic_room()
+ self.assertEqual(code, HTTPStatus.OK, json_body)
+
+ # The 6th room hits the rate limit.
+ code, json_body = self._create_basic_room()
+ self.assertEqual(code, HTTPStatus.TOO_MANY_REQUESTS, json_body)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -1252,6 +1292,120 @@ class RoomJoinTestCase(RoomBase):
)
+class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ room.register_servlets,
+ synapse.rest.admin.register_servlets,
+ register.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.appservice_user, _ = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
+
+ # Create a room as the appservice user.
+ args = {
+ "access_token": self.appservice.token,
+ "user_id": self.appservice_user,
+ }
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/createRoom?{urlparse.urlencode(args)}",
+ content={"visibility": "public"},
+ )
+
+ assert channel.code == 200
+ self.room = channel.json_body["room_id"]
+
+ self.main_store = self.hs.get_datastores().main
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_send_event_ts(self) -> None:
+ """Test sending a non-state event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/send/m.room.message/1234?"
+ + urlparse.urlencode(url_params),
+ content={"body": "test", "msgtype": "m.text"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+ def test_send_state_event_ts(self) -> None:
+ """Test sending a state event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.name?"
+ + urlparse.urlencode(url_params),
+ content={"name": "test"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+ def test_send_membership_event_ts(self) -> None:
+ """Test sending a membership event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.member/{self.appservice_user}?"
+ + urlparse.urlencode(url_params),
+ content={"membership": "join", "display_name": "test"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -1272,10 +1426,22 @@ class RoomJoinRatelimitTestCase(RoomBase):
)
def test_join_local_ratelimit(self) -> None:
"""Tests that local joins are actually rate-limited."""
- for _ in range(3):
- self.helper.create_room_as(self.user_id)
+ # Create 4 rooms
+ room_ids = [
+ self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4)
+ ]
- self.helper.create_room_as(self.user_id, expect_code=429)
+ joiner_user_id = self.register_user("joiner", "secret")
+ # Now make a new user try to join some of them.
+
+ # The user can join 3 rooms
+ for room_id in room_ids[0:3]:
+ self.helper.join(room_id, joiner_user_id)
+
+ # But the user cannot join a 4th room
+ self.helper.join(
+ room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS
+ )
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
@@ -2098,14 +2264,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
)
def make_public_rooms_request(
- self, room_types: Union[List[Union[str, None]], None]
+ self,
+ room_types: Optional[List[Union[str, None]]],
+ instance_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
- channel = self.make_request(
- "POST",
- self.url,
- {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
- self.token,
- )
+ body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}
+ if instance_id:
+ body["third_party_instance_id"] = "test|test"
+
+ channel = self.make_request("POST", self.url, body, self.token)
+ self.assertEqual(channel.code, 200)
+
chunk = channel.json_body["chunk"]
count = channel.json_body["total_room_count_estimate"]
@@ -2115,31 +2284,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
chunk, count = self.make_public_rooms_request(None)
-
self.assertEqual(count, 2)
+ # Also check if there's no filter property at all in the body.
+ channel = self.make_request("POST", self.url, {}, self.token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["chunk"]), 2)
+ self.assertEqual(channel.json_body["total_room_count_estimate"], 2)
+
+ chunk, count = self.make_public_rooms_request(None, "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_rooms_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request([None])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), None)
+ chunk, count = self.make_public_rooms_request([None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space"])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), "m.space")
+ chunk, count = self.make_public_rooms_request(["m.space"], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space", None])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request(["m.space", None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
chunk, count = self.make_public_rooms_request([])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request([], "test|test")
+ self.assertEqual(count, 0)
+
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"""Test that we correctly fallback to local filtering if a remote server
@@ -2779,149 +2966,20 @@ class LabelsTestCase(unittest.HomeserverTestCase):
return event_id
-class RelationsTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["experimental_features"] = {"msc3440_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user_id = self.register_user("test", "test")
- self.tok = self.login("test", "test")
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- self.second_user_id = self.register_user("second", "test")
- self.second_tok = self.login("second", "test")
- self.helper.join(
- room=self.room_id, user=self.second_user_id, tok=self.second_tok
- )
-
- self.third_user_id = self.register_user("third", "test")
- self.third_tok = self.login("third", "test")
- self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
-
- # An initial event with a relation from second user.
- res = self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "Message 1"},
- tok=self.tok,
- )
- self.event_id_1 = res["event_id"]
- self.helper.send_event(
- room_id=self.room_id,
- type="m.reaction",
- content={
- "m.relates_to": {
- "rel_type": RelationTypes.ANNOTATION,
- "event_id": self.event_id_1,
- "key": "👍",
- }
- },
- tok=self.second_tok,
- )
-
- # Another event with a relation from third user.
- res = self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "Message 2"},
- tok=self.tok,
- )
- self.event_id_2 = res["event_id"]
- self.helper.send_event(
- room_id=self.room_id,
- type="m.reaction",
- content={
- "m.relates_to": {
- "rel_type": RelationTypes.REFERENCE,
- "event_id": self.event_id_2,
- }
- },
- tok=self.third_tok,
- )
-
- # An event with no relations.
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "No relations"},
- tok=self.tok,
- )
-
- def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+class RelationsTestCase(PaginationTestCase):
+ def _filter_messages(self, filter: JsonDict) -> List[str]:
"""Make a request to /messages with a filter, returns the chunk of events."""
+ from_token = self.get_success(
+ self.from_token.to_string(self.hs.get_datastores().main)
+ )
channel = self.make_request(
"GET",
- "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+ f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}",
access_token=self.tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- return channel.json_body["chunk"]
-
- def test_filter_relation_senders(self) -> None:
- # Messages which second user reacted to.
- filter = {"related_by_senders": [self.second_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
-
- # Messages which third user reacted to.
- filter = {"related_by_senders": [self.third_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_2)
-
- # Messages which either user reacted to.
- filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
- )
-
- def test_filter_relation_type(self) -> None:
- # Messages which have annotations.
- filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
-
- # Messages which have references.
- filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_2)
-
- # Messages which have either annotations or references.
- filter = {
- "related_by_rel_types": [
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- ]
- }
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
- )
-
- def test_filter_relation_senders_and_type(self) -> None:
- # Messages which second user reacted to.
- filter = {
- "related_by_senders": [self.second_user_id],
- "related_by_rel_types": [RelationTypes.ANNOTATION],
- }
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
class ContextTestCase(unittest.HomeserverTestCase):
@@ -3461,3 +3519,83 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
+
+ def test_400_missing_param_without_id_access_token(self) -> None:
+ """
+ Test that a 3pid invite request returns 400 M_MISSING_PARAM
+ if we do not include id_access_token.
+ """
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "medium": "email",
+ "address": "teresa@example.com",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
+
+
+class TimestampLookupTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = {"msc3030_enabled": True}
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self._storage_controllers = self.hs.get_storage_controllers()
+
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ def _inject_outlier(self, room_id: str) -> EventBase:
+ event, _context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=room_id,
+ type="m.test",
+ sender="@test_remote_user:remote",
+ )
+ )
+
+ event.internal_metadata.outlier = True
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
+ )
+ )
+ return event
+
+ def test_no_outliers(self) -> None:
+ """
+ Test to make sure `/timestamp_to_event` does not return `outlier` events.
+ We're unable to determine whether an `outlier` is next to a gap so we
+ don't know whether it's actually the closest event. Instead, let's just
+ ignore `outliers` with this endpoint.
+
+ This test is really seeing that we choose the non-`outlier` event behind the
+ `outlier`. Since the gap checking logic considers the latest message in the room
+ as *not* next to a gap, asking over federation does not come into play here.
+ """
+ room_id = self.helper.create_room_as(self.room_owner, tok=self.room_owner_tok)
+
+ outlier_event = self._inject_outlier(room_id)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
+ access_token=self.room_owner_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # Make sure the outlier event is not returned
+ self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c807a37bc2 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
room_upgrade_rest_servlet,
)
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase):
channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
- {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ {
+ "id_server": "test",
+ "medium": "email",
+ "address": "test@test.test",
+ "id_access_token": "anytoken",
+ },
access_token=self.banned_access_token,
)
self.assertEqual(200, channel.code, channel.result)
@@ -275,7 +280,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
@@ -310,7 +315,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b085c50356..0af643ecd9 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException
-from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@@ -390,6 +389,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ return self.setup_test_homeserver(config=config)
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -408,7 +412,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Join the second user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_read_receipts(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -416,7 +419,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt to tell the server the first user's message was read
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok2,
)
@@ -425,7 +428,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's private read receipt
self.assertIsNone(self._get_read_receipt())
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_public_receipt_can_override_private(self) -> None:
"""
Sending a public read receipt to the same event which has a private read
@@ -456,7 +458,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we did override the private read receipt
self.assertNotEqual(self._get_read_receipt(), None)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_receipt_cannot_override_public(self) -> None:
"""
Sending a private read receipt to the same event which has a public read
@@ -543,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
config = super().default_config()
config["experimental_features"] = {
"msc2654_enabled": True,
- "msc2285_enabled": True,
}
return config
@@ -624,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok,
)
@@ -700,7 +700,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(5)
res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
- # Make sure both m.read and org.matrix.msc2285.read.private advance
+ # Make sure both m.read and m.read.private advance
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
@@ -712,16 +712,21 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # We test for both receipt types that influence notification counts
- @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE])
- def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None:
+ # We test for all three receipt types that influence notification counts
+ @parameterized.expand(
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ]
+ )
+ def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
# Join the new user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@@ -732,18 +737,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Read last event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # Make sure neither m.read nor org.matrix.msc2285.read.private make the
+ # Make sure neither m.read nor m.read.private make the
# read receipt go up to an older event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}",
{},
access_token=self.tok,
)
@@ -948,3 +953,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
+
+ def test_incremental_sync(self) -> None:
+ """Tests that activity in the room is properly filtered out of incremental
+ syncs.
+ """
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+ next_batch = channel.json_body["next_batch"]
+
+ self.helper.send(self.excluded_room_id, tok=self.tok)
+ self.helper.send(self.included_room_id, tok=self.tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch}",
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 1083391b41..2e1b4753dc 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -156,7 +156,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
callback.assert_called_once()
@@ -174,7 +174,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
@@ -212,7 +212,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
access_token=self.tok,
)
# Check the error code
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, channel.result)
# Check the JSON body has had the `nasty` key injected
self.assertEqual(
channel.json_body,
@@ -261,7 +261,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
# ... and check that it got modified
@@ -270,7 +270,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
@@ -299,7 +299,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
orig_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -316,7 +316,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
edited_event_id = channel.json_body["event_id"]
# ... and check that they both got modified
@@ -325,7 +325,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
@@ -334,7 +334,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")
@@ -380,7 +380,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
@@ -389,7 +389,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar")
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index 61b66d7685..fdc433a8b5 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -59,7 +59,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.event_source.get_new_events(
user=UserID.from_string(self.user_id),
from_key=0,
- limit=None,
+ # Limit is unused.
+ limit=0,
room_ids=[self.room_id],
is_guest=False,
)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 105d418698..8d6f2b6ff9 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,7 +31,6 @@ from typing import (
Tuple,
overload,
)
-from unittest.mock import patch
from urllib.parse import urlencode
import attr
@@ -46,8 +45,19 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_ISSUER = "https://issuer.test/"
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "issuer": TEST_OIDC_ISSUER,
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["openid"],
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
@attr.s(auto_attribs=True)
@@ -140,7 +150,7 @@ class RestHelper:
custom_headers=custom_headers,
)
- assert channel.result["code"] == b"%d" % expect_code, channel.result
+ assert channel.code == expect_code, channel.result
self.auth_user_id = temp_id
if expect_code == HTTPStatus.OK:
@@ -213,11 +223,9 @@ class RestHelper:
data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -312,11 +320,9 @@ class RestHelper:
data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -396,11 +402,46 @@ class RestHelper:
custom_headers=custom_headers,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
+ channel.result["body"],
+ )
+
+ return channel.json_body
+
+ def get_event(
+ self,
+ room_id: str,
+ event_id: str,
+ tok: Optional[str] = None,
+ expect_code: int = HTTPStatus.OK,
+ ) -> JsonDict:
+ """Request a specific event from the server.
+
+ Args:
+ room_id: the room in which the event was sent.
+ event_id: the event's ID.
+ tok: the token to request the event with.
+ expect_code: the expected HTTP status for the response.
+
+ Returns:
+ The event as a dict.
+ """
+ path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}"
+ if tok:
+ path = path + f"?access_token={tok}"
+
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ path,
+ )
+
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
+ expect_code,
+ channel.code,
channel.result["body"],
)
@@ -449,11 +490,9 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -545,13 +584,62 @@ class RestHelper:
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
return channel.json_body
- def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ def whoami(
+ self,
+ access_token: str,
+ expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
+ ) -> JsonDict:
+ """Perform a 'whoami' request, which can be a quick way to check for access
+ token validity
+
+ Args:
+ access_token: The user token to use during the request
+ expect_code: The return code to expect from attempting the whoami request
+ """
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "account/whoami",
+ access_token=access_token,
+ )
+
+ assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % (
+ expect_code,
+ channel.code,
+ channel.result["body"],
+ )
+
+ return channel.json_body
+
+ def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
+ """Create a ``FakeOidcServer``.
+
+ This can be used in conjuction with ``login_via_oidc``::
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
+ """
+
+ return FakeOidcServer(
+ clock=self.hs.get_clock(),
+ issuer=issuer,
+ )
+
+ def login_via_oidc(
+ self,
+ fake_server: FakeOidcServer,
+ remote_user_id: str,
+ with_sid: bool = False,
+ idp_id: Optional[str] = None,
+ expected_status: int = 200,
+ ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
Returns the result of the final token login.
@@ -564,7 +652,14 @@ class RestHelper:
the normal places.
"""
client_redirect_url = "https://x"
- channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+ userinfo = {"sub": remote_user_id}
+ channel, grant = self.auth_via_oidc(
+ fake_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=with_sid,
+ idp_id=idp_id,
+ )
# expect a confirmation page
assert channel.code == HTTPStatus.OK, channel.result
@@ -586,15 +681,20 @@ class RestHelper:
"/login",
content={"type": "m.login.token", "token": login_token},
)
- assert channel.code == HTTPStatus.OK
- return channel.json_body
+ assert (
+ channel.code == expected_status
+ ), f"unexpected status in response: {channel.code}"
+ return channel.json_body, grant
def auth_via_oidc(
self,
+ fake_server: FakeOidcServer,
user_info_dict: JsonDict,
client_redirect_url: Optional[str] = None,
ui_auth_session_id: Optional[str] = None,
- ) -> FakeChannel:
+ with_sid: bool = False,
+ idp_id: Optional[str] = None,
+ ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Perform an OIDC authentication flow via a mock OIDC provider.
This can be used for either login or user-interactive auth.
@@ -618,6 +718,8 @@ class RestHelper:
the login redirect endpoint
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
of the UI auth.
+ with_sid: if True, generates a random `sid` (OIDC session ID)
+ idp_id: if set, explicitely chooses one specific IDP
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -627,14 +729,17 @@ class RestHelper:
cookies: Dict[str, str] = {}
- # if we're doing a ui auth, hit the ui auth redirect endpoint
- if ui_auth_session_id:
- # can't set the client redirect url for UI Auth
- assert client_redirect_url is None
- oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
- else:
- # otherwise, hit the login redirect endpoint
- oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+ with fake_server.patch_homeserver(hs=self.hs):
+ # if we're doing a ui auth, hit the ui auth redirect endpoint
+ if ui_auth_session_id:
+ # can't set the client redirect url for UI Auth
+ assert client_redirect_url is None
+ oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+ else:
+ # otherwise, hit the login redirect endpoint
+ oauth_uri = self.initiate_sso_login(
+ client_redirect_url, cookies, idp_id=idp_id
+ )
# we now have a URI for the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state"
@@ -642,17 +747,21 @@ class RestHelper:
# that synapse passes to the client.
oauth_uri_path, _ = oauth_uri.split("?", 1)
- assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+ assert oauth_uri_path == fake_server.authorization_endpoint, (
"unexpected SSO URI " + oauth_uri_path
)
- return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+ return self.complete_oidc_auth(
+ fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
+ )
def complete_oidc_auth(
self,
+ fake_serer: FakeOidcServer,
oauth_uri: str,
cookies: Mapping[str, str],
user_info_dict: JsonDict,
- ) -> FakeChannel:
+ with_sid: bool = False,
+ ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Mock out an OIDC authentication flow
Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@@ -663,50 +772,37 @@ class RestHelper:
Requires the OIDC callback resource to be mounted at the normal place.
Args:
+ fake_server: the fake OIDC server with which the auth should be done
oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
from initiate_sso_login or initiate_sso_ui_auth).
cookies: the cookies set by synapse's redirect endpoint, which will be
sent back to the callback endpoint.
user_info_dict: the remote userinfo that the OIDC provider should present.
Typically this should be '{"sub": "<remote user id>"}'.
+ with_sid: if True, generates a random `sid` (OIDC session ID)
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
"""
_, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
+
+ code, grant = fake_serer.start_authorization(
+ scope=params["scope"][0],
+ userinfo=user_info_dict,
+ client_id=params["client_id"][0],
+ redirect_uri=params["redirect_uri"][0],
+ nonce=params["nonce"][0],
+ with_sid=with_sid,
+ )
+ state = params["state"][0]
+
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
- urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
- )
-
- # before we hit the callback uri, stub out some methods in the http client so
- # that we don't have to handle full HTTPS requests.
- # (expected url, json response) pairs, in the order we expect them.
- expected_requests = [
- # first we get a hit to the token endpoint, which we tell to return
- # a dummy OIDC access token
- (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
- # and then one to the user_info endpoint, which returns our remote user id.
- (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
- ]
-
- async def mock_req(
- method: str,
- uri: str,
- data: Optional[dict] = None,
- headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
- ):
- (expected_uri, resp_obj) = expected_requests.pop(0)
- assert uri == expected_uri
- resp = FakeResponse(
- code=HTTPStatus.OK,
- phrase=b"OK",
- body=json.dumps(resp_obj).encode("utf-8"),
- )
- return resp
+ urllib.parse.urlencode({"state": state, "code": code}),
+ )
- with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
self.hs.get_reactor(),
@@ -717,10 +813,13 @@ class RestHelper:
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
],
)
- return channel
+ return channel, grant
def initiate_sso_login(
- self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+ self,
+ client_redirect_url: Optional[str],
+ cookies: MutableMapping[str, str],
+ idp_id: Optional[str] = None,
) -> str:
"""Make a request to the login-via-sso redirect endpoint, and return the target
@@ -731,6 +830,7 @@ class RestHelper:
client_redirect_url: the client redirect URL to pass to the login redirect
endpoint
cookies: any cookies returned will be added to this dict
+ idp_id: if set, explicitely chooses one specific IDP
Returns:
the URI that the client gets redirected to (ie, the SSO server)
@@ -739,6 +839,12 @@ class RestHelper:
if client_redirect_url:
params["redirectUrl"] = client_redirect_url
+ uri = "/_matrix/client/r0/login/sso/redirect"
+ if idp_id is not None:
+ uri = f"{uri}/{idp_id}"
+
+ uri = f"{uri}?{urllib.parse.urlencode(params)}"
+
# hit the redirect url (which should redirect back to the redirect url. This
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
@@ -746,7 +852,7 @@ class RestHelper:
self.hs.get_reactor(),
self.site,
"GET",
- "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+ uri,
)
assert channel.code == 302
@@ -808,21 +914,3 @@ class RestHelper:
assert len(p.links) == 1, "not exactly one link in confirmation page"
oauth_uri = p.links[0]
return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
-TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
-TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
-TEST_OIDC_CONFIG = {
- "enabled": True,
- "discover": False,
- "issuer": "https://issuer.test",
- "client_id": "test-client-id",
- "client_secret": "test-client-secret",
- "scopes": ["profile"],
- "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
- "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
- "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
- "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-}
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index ac0ac06b7e..7f1fba1086 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest
-from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.key.v2 import KeyResource
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
@@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def create_test_resource(self) -> Resource:
return create_resource_tree(
- {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+ {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource()
)
def expect_outgoing_key_request(
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 14af07c5af..23f227aed6 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -13,7 +13,9 @@
# limitations under the License.
import io
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional
+
+from matrix_common.types.mxc_uri import MXCUri
from twisted.test.proto_helpers import MemoryReactor
@@ -63,9 +65,9 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
last_accessed_ms: Optional[int],
is_quarantined: Optional[bool] = False,
is_protected: Optional[bool] = False,
- ) -> str:
+ ) -> MXCUri:
# "Upload" some media to the local media store
- mxc_uri = self.get_success(
+ mxc_uri: MXCUri = self.get_success(
media_repository.create_content(
media_type="text/plain",
upload_name=None,
@@ -75,13 +77,11 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
)
)
- media_id = mxc_uri.split("/")[-1]
-
# Set the last recently accessed time for this media
if last_accessed_ms is not None:
self.get_success(
self.store.update_cached_last_access_time(
- local_media=(media_id,),
+ local_media=(mxc_uri.media_id,),
remote_media=(),
time_ms=last_accessed_ms,
)
@@ -92,7 +92,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
self.get_success(
self.store.quarantine_media_by_id(
server_name=self.hs.config.server.server_name,
- media_id=media_id,
+ media_id=mxc_uri.media_id,
quarantined_by="@theadmin:test",
)
)
@@ -101,18 +101,18 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# Mark this media as protected from quarantine
self.get_success(
self.store.mark_local_media_as_safe(
- media_id=media_id,
+ media_id=mxc_uri.media_id,
safe=True,
)
)
- return media_id
+ return mxc_uri
def _cache_remote_media_and_set_attributes(
media_id: str,
last_accessed_ms: Optional[int],
is_quarantined: Optional[bool] = False,
- ) -> str:
+ ) -> MXCUri:
# Pretend to cache some remote media
self.get_success(
self.store.store_cached_remote_media(
@@ -146,7 +146,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
)
)
- return media_id
+ return MXCUri(self.remote_server_name, media_id)
# Start with the local media store
self.local_recently_accessed_media = _create_media_and_set_attributes(
@@ -214,28 +214,16 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# Remote media should be unaffected.
self._assert_if_mxc_uris_purged(
purged=[
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_media,
- ),
- (self.hs.config.server.server_name, self.local_never_accessed_media),
+ self.local_not_recently_accessed_media,
+ self.local_never_accessed_media,
],
not_purged=[
- (self.hs.config.server.server_name, self.local_recently_accessed_media),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_quarantined_media,
- ),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_protected_media,
- ),
- (self.remote_server_name, self.remote_recently_accessed_media),
- (self.remote_server_name, self.remote_not_recently_accessed_media),
- (
- self.remote_server_name,
- self.remote_not_recently_accessed_quarantined_media,
- ),
+ self.local_recently_accessed_media,
+ self.local_not_recently_accessed_quarantined_media,
+ self.local_not_recently_accessed_protected_media,
+ self.remote_recently_accessed_media,
+ self.remote_not_recently_accessed_media,
+ self.remote_not_recently_accessed_quarantined_media,
],
)
@@ -261,49 +249,35 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# Remote media accessed <30 days ago should still exist.
self._assert_if_mxc_uris_purged(
purged=[
- (self.remote_server_name, self.remote_not_recently_accessed_media),
+ self.remote_not_recently_accessed_media,
],
not_purged=[
- (self.remote_server_name, self.remote_recently_accessed_media),
- (self.hs.config.server.server_name, self.local_recently_accessed_media),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_media,
- ),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_quarantined_media,
- ),
- (
- self.hs.config.server.server_name,
- self.local_not_recently_accessed_protected_media,
- ),
- (
- self.remote_server_name,
- self.remote_not_recently_accessed_quarantined_media,
- ),
- (self.hs.config.server.server_name, self.local_never_accessed_media),
+ self.remote_recently_accessed_media,
+ self.local_recently_accessed_media,
+ self.local_not_recently_accessed_media,
+ self.local_not_recently_accessed_quarantined_media,
+ self.local_not_recently_accessed_protected_media,
+ self.remote_not_recently_accessed_quarantined_media,
+ self.local_never_accessed_media,
],
)
def _assert_if_mxc_uris_purged(
- self, purged: Iterable[Tuple[str, str]], not_purged: Iterable[Tuple[str, str]]
+ self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri]
) -> None:
- def _assert_mxc_uri_purge_state(
- server_name: str, media_id: str, expect_purged: bool
- ) -> None:
+ def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not."""
- if server_name == self.hs.config.server.server_name:
+ if mxc_uri.server_name == self.hs.config.server.server_name:
found_media_dict = self.get_success(
- self.store.get_local_media(media_id)
+ self.store.get_local_media(mxc_uri.media_id)
)
else:
found_media_dict = self.get_success(
- self.store.get_cached_remote_media(server_name, media_id)
+ self.store.get_cached_remote_media(
+ mxc_uri.server_name, mxc_uri.media_id
+ )
)
- mxc_uri = f"mxc://{server_name}/{media_id}"
-
if expect_purged:
self.assertIsNone(
found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
@@ -315,7 +289,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
)
# Assert that the given MXC URIs have either been correctly purged or not.
- for server_name, media_id in purged:
- _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=True)
- for server_name, media_id in not_purged:
- _assert_mxc_uri_purge_state(server_name, media_id, expect_purged=False)
+ for mxc_uri in purged:
+ _assert_mxc_uri_purge_state(mxc_uri, expect_purged=True)
+ for mxc_uri in not_purged:
+ _assert_mxc_uri_purge_state(mxc_uri, expect_purged=False)
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index f38d7225f8..319ae8b1cc 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -14,6 +14,8 @@
import json
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
@@ -23,8 +25,16 @@ from synapse.util import Clock
from tests.unittest import HomeserverTestCase
+try:
+ import lxml
+except ImportError:
+ lxml = None
+
class OEmbedTests(HomeserverTestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.oembed = OEmbedProvider(hs)
@@ -36,7 +46,7 @@ class OEmbedTests(HomeserverTestCase):
def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1):
- result = self.parse_response({"version": version, "type": "link"})
+ result = self.parse_response({"version": version})
# An empty Open Graph response is an error, ensure the URL is included.
self.assertIn("og:url", result.open_graph_result)
@@ -49,3 +59,94 @@ class OEmbedTests(HomeserverTestCase):
result = self.parse_response({"version": version, "type": "link"})
# An empty Open Graph response is an error, ensure the URL is included.
self.assertEqual({}, result.open_graph_result)
+
+ def test_cache_age(self) -> None:
+ """Ensure a cache-age is parsed properly."""
+ # Correct-ish cache ages are allowed.
+ for cache_age in ("1", 1.0, 1):
+ result = self.parse_response({"cache_age": cache_age})
+ self.assertEqual(result.cache_age, 1000)
+
+ # Invalid cache ages are ignored.
+ for cache_age in ("invalid", {}):
+ result = self.parse_response({"cache_age": cache_age})
+ self.assertIsNone(result.cache_age)
+
+ # Cache age is optional.
+ result = self.parse_response({})
+ self.assertIsNone(result.cache_age)
+
+ @parameterized.expand(
+ [
+ ("title", "title"),
+ ("provider_name", "site_name"),
+ ("thumbnail_url", "image"),
+ ],
+ name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}",
+ )
+ def test_property(self, oembed_property: str, open_graph_property: str) -> None:
+ """Test properties which must be strings."""
+ result = self.parse_response({oembed_property: "test"})
+ self.assertIn(f"og:{open_graph_property}", result.open_graph_result)
+ self.assertEqual(result.open_graph_result[f"og:{open_graph_property}"], "test")
+
+ result = self.parse_response({oembed_property: 1})
+ self.assertNotIn(f"og:{open_graph_property}", result.open_graph_result)
+
+ def test_author_name(self) -> None:
+ """Test the author_name property."""
+ result = self.parse_response({"author_name": "test"})
+ self.assertEqual(result.author_name, "test")
+
+ result = self.parse_response({"author_name": 1})
+ self.assertIsNone(result.author_name)
+
+ def test_rich(self) -> None:
+ """Test a type of rich."""
+ result = self.parse_response({"html": "test<img src='foo'>", "type": "rich"})
+ self.assertIn("og:description", result.open_graph_result)
+ self.assertIn("og:image", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:description"], "test")
+ self.assertEqual(result.open_graph_result["og:image"], "foo")
+
+ result = self.parse_response({"type": "rich"})
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ result = self.parse_response({"html": 1, "type": "rich"})
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ def test_photo(self) -> None:
+ """Test a type of photo."""
+ result = self.parse_response({"url": "test", "type": "photo"})
+ self.assertIn("og:image", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:image"], "test")
+
+ result = self.parse_response({"type": "photo"})
+ self.assertNotIn("og:image", result.open_graph_result)
+
+ result = self.parse_response({"url": 1, "type": "photo"})
+ self.assertNotIn("og:image", result.open_graph_result)
+
+ def test_video(self) -> None:
+ """Test a type of video."""
+ result = self.parse_response({"html": "test", "type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertIn("og:description", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:description"], "test")
+
+ result = self.parse_response({"type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ result = self.parse_response({"url": 1, "type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ def test_link(self) -> None:
+ """Test type of link."""
+ result = self.parse_response({"type": "link"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "website")
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index da325955f8..c0a2501742 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-
from synapse.rest.health import HealthResource
from tests import unittest
@@ -26,5 +24,5 @@ class HealthCheckTests(unittest.HomeserverTestCase):
def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index d8faafec75..2091b08d89 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -38,7 +36,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -57,7 +55,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(channel.code, 404)
@unittest.override_config(
{
@@ -71,7 +69,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/client", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{
@@ -87,7 +85,7 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.code, 200)
self.assertEqual(
channel.json_body,
{"m.server": "test:443"},
@@ -97,4 +95,4 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+ self.assertEqual(channel.code, 404)
|