diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 192073c520..af849bd471 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
% server_and_media_id_2
),
)
+
+
+class PurgeHistoryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}"
+ self.url_status = "/_synapse/admin/v1/purge_history_status/"
+
+ def test_purge_history(self):
+ """
+ Simple test of purge history API.
+ Test only that is is possible to call, get status 200 and purge_id.
+ """
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ content={"delete_local_events": True, "purge_up_to_ts": 0},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("purge_id", channel.json_body)
+ purge_id = channel.json_body["purge_id"]
+
+ # get status
+ channel = self.make_request(
+ "GET",
+ self.url_status + purge_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("complete", channel.json_body["status"])
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 78c48db552..cd5c60b65c 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -11,10 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from http import HTTPStatus
+from typing import Collection
+
+from parameterized import parameterized
import synapse.rest.admin
+from synapse.api.errors import Codes
from synapse.rest.client import login
from synapse.server import HomeServer
+from synapse.storage.background_updates import BackgroundUpdater
from tests import unittest
@@ -30,6 +36,60 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+ @parameterized.expand(
+ [
+ ("GET", "/_synapse/admin/v1/background_updates/enabled"),
+ ("POST", "/_synapse/admin/v1/background_updates/enabled"),
+ ("GET", "/_synapse/admin/v1/background_updates/status"),
+ ("POST", "/_synapse/admin/v1/background_updates/start_job"),
+ ]
+ )
+ def test_requester_is_no_admin(self, method: str, url: str):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ self.register_user("user", "pass", admin=False)
+ other_user_tok = self.login("user", "pass")
+
+ channel = self.make_request(
+ method,
+ url,
+ content={},
+ access_token=other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ url = "/_synapse/admin/v1/background_updates/start_job"
+
+ # empty content
+ channel = self.make_request(
+ "POST",
+ url,
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # job_name invalid
+ channel = self.make_request(
+ "POST",
+ url,
+ content={"job_name": "unknown"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
def _register_bg_update(self):
"Adds a bg update but doesn't start it"
@@ -60,7 +120,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled, but none should be running.
self.assertDictEqual(
@@ -82,7 +142,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled, and one should be running.
self.assertDictEqual(
@@ -91,9 +151,11 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": True,
@@ -114,7 +176,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/enabled",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
# Disable the BG updates
@@ -124,7 +186,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": False},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": False})
# Advance a bit and get the current status, note this will finish the in
@@ -137,16 +199,18 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(
channel.json_body,
{
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": False,
@@ -162,7 +226,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# There should be no change from the previous /status response.
self.assertDictEqual(
@@ -171,9 +235,11 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 1000.0,
- "total_item_count": 100,
+ "total_item_count": (
+ BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": False,
@@ -188,7 +254,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
content={"enabled": True},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertDictEqual(channel.json_body, {"enabled": True})
@@ -199,7 +265,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"/_synapse/admin/v1/background_updates/status",
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
# Background updates should be enabled and making progress.
self.assertDictEqual(
@@ -208,11 +274,92 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
"current_updates": {
"master": {
"name": "test_update",
- "average_items_per_ms": 0.1,
+ "average_items_per_ms": 0.001,
"total_duration_ms": 2000.0,
- "total_item_count": 200,
+ "total_item_count": (
+ 2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+ ),
}
},
"enabled": True,
},
)
+
+ @parameterized.expand(
+ [
+ ("populate_stats_process_rooms", ["populate_stats_process_rooms"]),
+ (
+ "regenerate_directory",
+ [
+ "populate_user_directory_createtables",
+ "populate_user_directory_process_rooms",
+ "populate_user_directory_process_users",
+ "populate_user_directory_cleanup",
+ ],
+ ),
+ ]
+ )
+ def test_start_backround_job(self, job_name: str, updates: Collection[str]):
+ """
+ Test that background updates add to database and be processed.
+
+ Args:
+ job_name: name of the job to call with API
+ updates: collection of background updates to be started
+ """
+
+ # no background update is waiting
+ self.assertTrue(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ )
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/start_job",
+ content={"job_name": job_name},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # test that each background update is waiting now
+ for update in updates:
+ self.assertFalse(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_update(update)
+ )
+ )
+
+ self.wait_for_background_updates()
+
+ # background updates are done
+ self.assertTrue(
+ self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ )
+ )
+
+ def test_start_backround_job_twice(self):
+ """Test that add a background update twice return an error."""
+
+ # add job to database
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/background_updates/start_job",
+ content={"job_name": "populate_stats_process_rooms"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46116644ce..07077aff78 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -14,12 +14,16 @@
import json
import urllib.parse
+from http import HTTPStatus
from typing import List, Optional
from unittest.mock import Mock
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes
+from synapse.handlers.pagination import PaginationHandler
from synapse.rest.client import directory, events, login, room
from tests import unittest
@@ -68,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- json.dumps({}),
+ {},
access_token=self.other_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_room_does_not_exist(self):
@@ -84,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
url,
- json.dumps({}),
+ {},
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_room_is_not_valid(self):
@@ -100,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
url,
- json.dumps({}),
+ {},
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom is not a legal room ID",
channel.json_body["error"],
@@ -119,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("new_room_id", channel.json_body)
self.assertIn("kicked_users", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -138,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"User must be our own: @not:exist.bla",
channel.json_body["error"],
@@ -157,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self):
@@ -173,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self):
@@ -199,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -232,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -266,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"DELETE",
self.url.encode("ascii"),
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -281,6 +285,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._is_blocked(self.room_id, expect=True)
self._has_no_members(self.room_id)
+ @parameterized.expand([(True,), (False,)])
+ def test_block_unknown_room(self, purge: bool) -> None:
+ """
+ We can block an unknown room. In this case, the `purge` argument
+ should be ignored.
+ """
+ room_id = "!unknown:test"
+
+ # The room isn't already in the blocked rooms table
+ self._is_blocked(room_id, expect=False)
+
+ # Request the room be blocked.
+ channel = self.make_request(
+ "DELETE",
+ f"/_synapse/admin/v1/rooms/{room_id}",
+ {"block": True, "purge": purge},
+ access_token=self.admin_user_tok,
+ )
+
+ # The room is now blocked.
+ self.assertEqual(
+ HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"]
+ )
+ self._is_blocked(room_id)
+
def test_shutdown_room_consent(self):
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
@@ -316,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -345,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
json.dumps({"history_visibility": "world_readable"}),
access_token=self.other_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Test that room is not purged
with self.assertRaises(AssertionError):
@@ -362,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("new_room_id", channel.json_body)
self.assertIn("failed_to_kick_users", channel.json_body)
@@ -418,17 +447,616 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+ url = "events?timeout=0&room_id=" + room_id
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+
+class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ events.register_servlets,
+ room.register_servlets,
+ room.register_deprecated_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.event_creation_handler = hs.get_event_creation_handler()
+ hs.config.consent.user_consent_version = "1"
+
+ consent_uri_builder = Mock()
+ consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+ self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # Mark the admin user as having consented
+ self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = f"/_synapse/admin/v2/rooms/{self.room_id}"
+ self.url_status_by_room_id = (
+ f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status"
+ )
+ self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/"
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+ ]
+ )
+ def test_requester_is_no_admin(self, method: str, url: str):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ channel = self.make_request(
+ method,
+ url % self.room_id,
+ content={},
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
+ ]
+ )
+ def test_room_does_not_exist(self, method: str, url: str):
+ """
+ Check that unknown rooms/server return error 404.
+ """
+
+ channel = self.make_request(
+ method,
+ url % "!unknown:test",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ @parameterized.expand(
+ [
+ ("DELETE", "/_synapse/admin/v2/rooms/%s"),
+ ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
+ ]
+ )
+ def test_room_is_not_valid(self, method: str, url: str):
+ """
+ Check that invalid room names, return an error 400.
+ """
+
+ channel = self.make_request(
+ method,
+ url % "invalidroom",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
+ )
+
+ def test_new_room_user_does_not_exist(self):
+ """
+ Tests that the user ID must be from local server but it does not have to exist.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": "@unknown:test"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ def test_new_room_user_is_not_local(self):
+ """
+ Check that only local users can create new room to move members.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": "@not:exist.bla"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ "User must be our own: @not:exist.bla",
+ channel.json_body["error"],
)
+ def test_block_is_not_bool(self):
+ """
+ If parameter `block` is not boolean, return an error
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"block": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_purge_is_not_bool(self):
+ """
+ If parameter `purge` is not boolean, return an error
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"purge": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ def test_delete_expired_status(self):
+ """Test that the task status is removed after expiration."""
+
+ # first task, do not purge, that we can create a second task
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"purge": False},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id1 = channel.json_body["delete_id"]
+
+ # go ahead
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ # second task
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id2 = channel.json_body["delete_id"]
+
+ # get status
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(2, len(channel.json_body["results"]))
+ self.assertEqual("complete", channel.json_body["results"][0]["status"])
+ self.assertEqual("complete", channel.json_body["results"][1]["status"])
+ self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"])
+ self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"])
+
+ # get status after more than clearing time for first task
+ # second task is not cleared
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+ self.assertEqual("complete", channel.json_body["results"][0]["status"])
+ self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
+
+ # get status after more than clearing time for all tasks
+ self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_delete_same_room_twice(self):
+ """Test that the call for delete a room at second time gives an exception."""
+
+ body = {"new_room_user_id": self.admin_user}
+
+ # first call to delete room
+ # and do not wait for finish the task
+ first_channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content=body,
+ access_token=self.admin_user_tok,
+ await_result=False,
+ )
+
+ # second call to delete room
+ second_channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content=body,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body
+ )
+ self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"])
+ self.assertEqual(
+ f"History purge already in progress for {self.room_id}",
+ second_channel.json_body["error"],
+ )
+
+ # get result of first call
+ first_channel.await_result()
+ self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body)
+ self.assertIn("delete_id", first_channel.json_body)
+
+ # check status after finish the task
+ self._test_result(
+ first_channel.json_body["delete_id"],
+ self.other_user,
+ expect_new_room=True,
+ )
+
+ def test_purge_room_and_block(self):
+ """Test to purge a room and block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": True, "purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_purge_room_and_not_block(self):
+ """Test to purge a room and do not block it.
+ Members will not be moved to a new room and will not receive a message.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": False, "purge": True},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=False)
+ self._has_no_members(self.room_id)
+
+ def test_block_room_and_not_purge(self):
+ """Test to block a room without purging it.
+ Members will not be moved to a new room and will not receive a message.
+ The room will not be purged.
+ """
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Test that room is not blocked
+ self._is_blocked(self.room_id, expect=False)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ channel = self.make_request(
+ "DELETE",
+ self.url.encode("ascii"),
+ content={"block": True, "purge": False},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user)
+
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+ self._is_blocked(self.room_id, expect=True)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_consent(self):
+ """Test that we can shutdown rooms with local users who have not
+ yet accepted the privacy policy. This used to fail when we tried to
+ force part the user from the old room.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Assert one user in room
+ users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+ self.assertEqual([self.other_user], users_in_room)
+
+ # Enable require consent to send events
+ self.event_creation_handler._block_events_without_consent_error = "Error"
+
+ # Assert that the user is getting consent error
+ self.helper.send(
+ self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+ )
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+ user_id=self.other_user,
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ def test_shutdown_room_block_peek(self):
+ """Test that a world_readable room can no longer be peeked into after
+ it has been shut down.
+ Members will be moved to a new room and will receive a message.
+ """
+ self.event_creation_handler._block_events_without_consent_error = None
+
+ # Enable world readable
+ url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+ channel = self.make_request(
+ "PUT",
+ url.encode("ascii"),
+ content={"history_visibility": "world_readable"},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # Test that room is not purged
+ with self.assertRaises(AssertionError):
+ self._is_purged(self.room_id)
+
+ # Assert one user in room
+ self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+ # Test that the admin can still send shutdown
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ content={"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertIn("delete_id", channel.json_body)
+ delete_id = channel.json_body["delete_id"]
+
+ self._test_result(delete_id, self.other_user, expect_new_room=True)
+
+ channel = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertEqual(1, len(channel.json_body["results"]))
+
+ # Test that member has moved to new room
+ self._is_member(
+ room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"],
+ user_id=self.other_user,
+ )
+
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
+ # Assert we can no longer peek into the room
+ self._assert_peek(self.room_id, expect_code=403)
+
+ def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+ """Assert that the room is blocked or not"""
+ d = self.store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _has_no_members(self, room_id: str) -> None:
+ """Assert there is now no longer anyone in the room"""
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertEqual([], users_in_room)
+
+ def _is_member(self, room_id: str, user_id: str) -> None:
+ """Test that user is member of the room"""
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertIn(user_id, users_in_room)
+
+ def _is_purged(self, room_id: str) -> None:
+ """Test that the following tables have been purged of all rows related to the room."""
+ for table in PURGE_TABLES:
+ count = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
+
+ def _assert_peek(self, room_id: str, expect_code: int) -> None:
+ """Assert that the admin user can (or cannot) peek into the room."""
+
+ url = f"rooms/{room_id}/initialSync"
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok
+ )
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
url = "events?timeout=0&room_id=" + room_id
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok
)
+ self.assertEqual(expect_code, channel.code, msg=channel.json_body)
+
+ def _test_result(
+ self,
+ delete_id: str,
+ kicked_user: str,
+ expect_new_room: bool = False,
+ ) -> None:
+ """
+ Test that the result is the expected.
+ Uses both APIs (status by room_id and delete_id)
+
+ Args:
+ delete_id: id of this purge
+ kicked_user: a user_id which is kicked from the room
+ expect_new_room: if we expect that a new room was created
+ """
+
+ # get information by room_id
+ channel_room_id = self.make_request(
+ "GET",
+ self.url_status_by_room_id,
+ access_token=self.admin_user_tok,
+ )
self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
+ HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body
)
+ self.assertEqual(1, len(channel_room_id.json_body["results"]))
+ self.assertEqual(
+ delete_id, channel_room_id.json_body["results"][0]["delete_id"]
+ )
+
+ # get information by delete_id
+ channel_delete_id = self.make_request(
+ "GET",
+ self.url_status_by_delete_id + delete_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(
+ HTTPStatus.OK,
+ channel_delete_id.code,
+ msg=channel_delete_id.json_body,
+ )
+
+ # test values that are the same in both responses
+ for content in [
+ channel_room_id.json_body["results"][0],
+ channel_delete_id.json_body,
+ ]:
+ self.assertEqual("complete", content["status"])
+ self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0])
+ self.assertIn("failed_to_kick_users", content["shutdown_room"])
+ self.assertIn("local_aliases", content["shutdown_room"])
+ self.assertNotIn("error", content)
+
+ if expect_new_room:
+ self.assertIsNotNone(content["shutdown_room"]["new_room_id"])
+ else:
+ self.assertIsNone(content["shutdown_room"]["new_room_id"])
class RoomTestCase(unittest.HomeserverTestCase):
@@ -466,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
# Check request completed successfully
- self.assertEqual(200, int(channel.code), msg=channel.json_body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that response json body contains a "rooms" key
self.assertTrue(
@@ -550,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(
- 200, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue("rooms" in channel.json_body)
for r in channel.json_body["rooms"]:
@@ -592,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self):
"""Test the correct attributes for a room are returned"""
@@ -615,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -647,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
url.encode("ascii"),
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Check that rooms were returned
self.assertTrue("rooms" in channel.json_body)
@@ -1107,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
{"room_id": room_id},
access_token=admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Set this new alias as the canonical alias for this room
self.helper.send_state(
@@ -1157,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.second_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self):
@@ -1173,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self):
@@ -1189,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self):
@@ -1205,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"This endpoint can only be used with local users",
channel.json_body["error"],
@@ -1225,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("No known servers", channel.json_body["error"])
def test_room_is_not_valid(self):
@@ -1242,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
"invalidroom was not legal room ID or room alias",
channel.json_body["error"],
@@ -1261,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1275,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self):
@@ -1292,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self):
@@ -1324,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.admin_user_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
# Join user to room.
@@ -1335,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1348,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self):
@@ -1365,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
url,
- content=body.encode(encoding="utf_8"),
+ content=body,
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["room_id"])
# Validate if user is a member of the room
@@ -1379,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/joined_rooms",
access_token=self.second_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self):
@@ -1413,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=tok,
)
- self.assertEquals(
- 403, int(channel.result["code"]), msg=channel.result["body"]
- )
+ self.assertEquals(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self):
@@ -1445,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
% (room_id, events[midway]["event_id"]),
access_token=self.admin_user_tok,
)
- self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(200, channel.code, msg=channel.json_body)
self.assertEquals(
channel.json_body["event"]["event_id"], events[midway]["event_id"]
)
@@ -1504,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
@@ -1531,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room (we should have received an
# invite) and can ban a user.
@@ -1557,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and ban a user.
self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
@@ -1595,13 +2219,241 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
#
# (Note we assert the error message to ensure that it's not denied for
# some other reason)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(
channel.json_body["error"],
"No local admin user in room with power to update power levels.",
)
+class BlockRoomTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self._store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok
+ )
+ self.url = "/_synapse/admin/v1/rooms/%s/block"
+
+ @parameterized.expand([("PUT",), ("GET",)])
+ def test_requester_is_no_admin(self, method: str):
+ """If the user is not a server admin, an error 403 is returned."""
+
+ channel = self.make_request(
+ method,
+ self.url % self.room_id,
+ content={},
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @parameterized.expand([("PUT",), ("GET",)])
+ def test_room_is_not_valid(self, method: str):
+ """Check that invalid room names, return an error 400."""
+
+ channel = self.make_request(
+ method,
+ self.url % "invalidroom",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "invalidroom is not a legal room ID",
+ channel.json_body["error"],
+ )
+
+ def test_block_is_not_valid(self):
+ """If parameter `block` is not valid, return an error."""
+
+ # `block` is not valid
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": "NotBool"},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+ # `block` is not set
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # no content is send
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+ def test_block_room(self):
+ """Test that block a room is successful."""
+
+ def _request_and_test_block_room(room_id: str) -> None:
+ self._is_blocked(room_id, expect=False)
+ channel = self.make_request(
+ "PUT",
+ self.url % room_id,
+ content={"block": True},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self._is_blocked(room_id, expect=True)
+
+ # known internal room
+ _request_and_test_block_room(self.room_id)
+
+ # unknown internal room
+ _request_and_test_block_room("!unknown:test")
+
+ # unknown remote room
+ _request_and_test_block_room("!unknown:remote")
+
+ def test_block_room_twice(self):
+ """Test that block a room that is already blocked is successful."""
+
+ self._is_blocked(self.room_id, expect=False)
+ for _ in range(2):
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": True},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self._is_blocked(self.room_id, expect=True)
+
+ def test_unblock_room(self):
+ """Test that unblock a room is successful."""
+
+ def _request_and_test_unblock_room(room_id: str) -> None:
+ self._block_room(room_id)
+
+ channel = self.make_request(
+ "PUT",
+ self.url % room_id,
+ content={"block": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self._is_blocked(room_id, expect=False)
+
+ # known internal room
+ _request_and_test_unblock_room(self.room_id)
+
+ # unknown internal room
+ _request_and_test_unblock_room("!unknown:test")
+
+ # unknown remote room
+ _request_and_test_unblock_room("!unknown:remote")
+
+ def test_unblock_room_twice(self):
+ """Test that unblock a room that is not blocked is successful."""
+
+ self._block_room(self.room_id)
+ for _ in range(2):
+ channel = self.make_request(
+ "PUT",
+ self.url % self.room_id,
+ content={"block": False},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self._is_blocked(self.room_id, expect=False)
+
+ def test_get_blocked_room(self):
+ """Test get status of a blocked room"""
+
+ def _request_blocked_room(room_id: str) -> None:
+ self._block_room(room_id)
+
+ channel = self.make_request(
+ "GET",
+ self.url % room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertTrue(channel.json_body["block"])
+ self.assertEqual(self.other_user, channel.json_body["user_id"])
+
+ # known internal room
+ _request_blocked_room(self.room_id)
+
+ # unknown internal room
+ _request_blocked_room("!unknown:test")
+
+ # unknown remote room
+ _request_blocked_room("!unknown:remote")
+
+ def test_get_unblocked_room(self):
+ """Test get status of a unblocked room"""
+
+ def _request_unblocked_room(room_id: str) -> None:
+ self._is_blocked(room_id, expect=False)
+
+ channel = self.make_request(
+ "GET",
+ self.url % room_id,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body["block"])
+ self.assertNotIn("user_id", channel.json_body)
+
+ # known internal room
+ _request_unblocked_room(self.room_id)
+
+ # unknown internal room
+ _request_unblocked_room("!unknown:test")
+
+ # unknown remote room
+ _request_unblocked_room("!unknown:remote")
+
+ def _is_blocked(self, room_id: str, expect: bool = True) -> None:
+ """Assert that the room is blocked or not"""
+ d = self._store.is_room_blocked(room_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertIsNone(self.get_success(d))
+
+ def _block_room(self, room_id: str) -> None:
+ """Block a room in database"""
+ self.get_success(self._store.block_room(room_id, self.other_user))
+ self._is_blocked(room_id, expect=True)
+
+
PURGE_TABLES = [
"current_state_events",
"event_backward_extremities",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 25e8d6cf27..5011e54563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.admin_user, device_id=None, valid_until_ms=None
)
)
self.other_user = self.register_user("user", "pass", displayname="User")
self.other_user_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.other_user, device_id=None, valid_until_ms=None
)
)
@@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
self.other_user
)
- def test_no_auth(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get information of an user without authentication.
"""
- channel = self.make_request("POST", self.url)
+ channel = self.make_request(method, self.url)
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_not_admin(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_requester_is_not_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
- channel = self.make_request("POST", self.url, access_token=other_user_token)
+ channel = self.make_request(method, self.url, access_token=other_user_token)
self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ @parameterized.expand(["POST", "DELETE"])
+ def test_user_is_not_local(self, method: str):
"""
Tests that shadow-banning for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
- channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+ channel = self.make_request(method, url, access_token=self.admin_user_tok)
self.assertEqual(400, channel.code, msg=channel.json_body)
def test_success(self):
@@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
self.assertTrue(result.shadow_banned)
+ # Un-shadow-ban the user.
+ channel = self.make_request(
+ "DELETE", self.url, access_token=self.admin_user_tok
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ # Ensure the user is no longer shadow-banned (and the cache was cleared).
+ result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+ self.assertFalse(result.shadow_banned)
+
class RateLimitTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index e2fcbdc63a..8552671431 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -598,7 +598,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
refresh_response.json_body["refresh_token"],
)
- @override_config({"access_token_lifetime": "1m"})
+ @override_config({"refreshable_access_token_lifetime": "1m"})
def test_refresh_token_expiration(self):
"""
The access token should have some time as specified in the config.
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index b9e3602552..249808b031 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"experimental_features": {"msc3244_enabled": False}})
def test_get_does_not_include_msc3244_fields_when_disabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
@@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def test_get_does_include_msc3244_fields_when_enabled(self):
access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
+ self.auth_handler.create_access_token_for_user_id(
self.user, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..aca03afd0e 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import directory, login, room
+from synapse.server import HomeServer
from synapse.types import RoomAlias
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["require_membership_for_aliases"] = True
@@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return self.hs
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ """Create two local users and access tokens for them.
+ One of them creates a room."""
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
@@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "test")
self.user_tok = self.login("user", "test")
- def test_state_event_not_in_room(self):
+ def test_state_event_not_in_room(self) -> None:
self.ensure_user_left_room()
- self.set_alias_via_state_event(403)
+ self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
- def test_directory_endpoint_not_in_room(self):
+ def test_directory_endpoint_not_in_room(self) -> None:
self.ensure_user_left_room()
- self.set_alias_via_directory(403)
+ self.set_alias_via_directory(HTTPStatus.FORBIDDEN)
- def test_state_event_in_room_too_long(self):
+ def test_state_event_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_state_event(400, alias_length=256)
+ self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256)
- def test_directory_in_room_too_long(self):
+ def test_directory_in_room_too_long(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_directory(400, alias_length=256)
+ self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256)
@override_config({"default_room_version": 5})
- def test_state_event_user_in_v5_room(self):
+ def test_state_event_user_in_v5_room(self) -> None:
"""Test that a regular user can add alias events before room v6"""
self.ensure_user_joined_room()
- self.set_alias_via_state_event(200)
+ self.set_alias_via_state_event(HTTPStatus.OK)
@override_config({"default_room_version": 6})
- def test_state_event_v6_room(self):
+ def test_state_event_v6_room(self) -> None:
"""Test that a regular user can *not* add alias events from room v6"""
self.ensure_user_joined_room()
- self.set_alias_via_state_event(403)
+ self.set_alias_via_state_event(HTTPStatus.FORBIDDEN)
- def test_directory_in_room(self):
+ def test_directory_in_room(self) -> None:
self.ensure_user_joined_room()
- self.set_alias_via_directory(200)
+ self.set_alias_via_directory(HTTPStatus.OK)
- def test_room_creation_too_long(self):
+ def test_room_creation_too_long(self) -> None:
url = "/_matrix/client/r0/createRoom"
# We use deliberately a localpart under the length threshold so
@@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
- def test_room_creation(self):
+ def test_room_creation(self) -> None:
url = "/_matrix/client/r0/createRoom"
# Check with an alias of allowed length. There should already be
@@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_deleting_alias_via_directory(self) -> None:
+ # Add an alias for the room. We must be joined to do so.
+ self.ensure_user_joined_room()
+ alias = self.set_alias_via_directory(HTTPStatus.OK)
+
+ # Then try to remove the alias
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_deleting_nonexistant_alias(self) -> None:
+ # Check that no alias exists
+ alias = "#potato:test"
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
+
+ # Then try to remove the alias
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body)
- def set_alias_via_state_event(self, expected_code, alias_length=5):
+ def set_alias_via_state_event(
+ self, expected_code: HTTPStatus, alias_length: int = 5
+ ) -> None:
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
self.room_id,
self.hs.hostname,
@@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, expected_code, channel.result)
- def set_alias_via_directory(self, expected_code, alias_length=5):
- url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
+ def set_alias_via_directory(
+ self, expected_code: HTTPStatus, alias_length: int = 5
+ ) -> str:
+ alias = self.random_alias(alias_length)
+ url = "/_matrix/client/r0/directory/room/%s" % alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
@@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
"PUT", url, request_data, access_token=self.user_tok
)
self.assertEqual(channel.code, expected_code, channel.result)
+ return alias
- def random_alias(self, length):
+ def random_alias(self, length: int) -> str:
return RoomAlias(random_string(length), self.hs.hostname).to_string()
- def ensure_user_left_room(self):
+ def ensure_user_left_room(self) -> None:
self.ensure_membership("leave")
- def ensure_user_joined_room(self):
+ def ensure_user_joined_room(self) -> None:
self.ensure_membership("join")
- def ensure_membership(self, membership):
+ def ensure_membership(self, membership: str) -> None:
try:
if membership == "leave":
self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a63f04bd41..19f5e46537 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fΓΆ&=o"')]
# (possibly experimental) login flows we expect to appear in the list after the normal
# ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+ADDITIONAL_LOGIN_FLOWS = [
+ {"type": "m.login.application_service"},
+ {"type": "uk.half-shot.msc2778.login.application_service"},
+]
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -812,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
jwt_secret = "secret"
jwt_algorithm = "HS256"
+ base_config = {
+ "enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ }
- def make_homeserver(self, reactor, clock):
- self.hs = self.setup_test_homeserver()
- self.hs.config.jwt.jwt_enabled = True
- self.hs.config.jwt.jwt_secret = self.jwt_secret
- self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
- return self.hs
+ def default_config(self):
+ config = super().default_config()
+
+ # If jwt_config has been defined (eg via @override_config), don't replace it.
+ if config.get("jwt_config") is None:
+ config["jwt_config"] = self.base_config
+
+ return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@@ -876,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
- @override_config(
- {
- "jwt_config": {
- "jwt_enabled": True,
- "secret": jwt_secret,
- "algorithm": jwt_algorithm,
- "issuer": "test-issuer",
- }
- }
- )
+ @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
def test_login_iss(self):
"""Test validating the issuer claim."""
# A valid issuer.
@@ -916,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
- @override_config(
- {
- "jwt_config": {
- "jwt_enabled": True,
- "secret": jwt_secret,
- "algorithm": jwt_algorithm,
- "audiences": ["test-audience"],
- }
- }
- )
+ @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self):
"""Test validating the audience claim."""
# A valid audience.
@@ -959,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
+ def test_login_default_sub(self):
+ """Test reading user ID from the default subject claim."""
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
+ def test_login_custom_sub(self):
+ """Test reading user ID from a custom subject claim."""
+ channel = self.jwt_login({"username": "frog"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@frog:test")
+
def test_login_no_token(self):
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1021,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
]
)
- def make_homeserver(self, reactor, clock):
- self.hs = self.setup_test_homeserver()
- self.hs.config.jwt.jwt_enabled = True
- self.hs.config.jwt.jwt_secret = self.jwt_pubkey
- self.hs.config.jwt.jwt_algorithm = "RS256"
- return self.hs
+ def default_config(self):
+ config = super().default_config()
+ config["jwt_config"] = {
+ "enabled": True,
+ "secret": self.jwt_pubkey,
+ "algorithm": "RS256",
+ }
+ return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 78c2fb86b9..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
# Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return config
def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
self.user_id, self.user_token = self._create_user("alice")
self.user2_id, self.user2_token = self._create_user("bob")
@@ -91,6 +94,49 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
+ def test_deny_invalid_event(self):
+ """Test that we deny relations on non-existant events"""
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ EventTypes.Message,
+ parent_id="foo",
+ content={"body": "foo", "msgtype": "m.text"},
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ # Unless that event is referenced from another event!
+ self.get_success(
+ self.hs.get_datastore().db_pool.simple_insert(
+ table="event_relations",
+ values={
+ "event_id": "bar",
+ "relates_to_id": "foo",
+ "relation_type": RelationTypes.THREAD,
+ },
+ desc="test_deny_invalid_event",
+ )
+ )
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ parent_id="foo",
+ content={"body": "foo", "msgtype": "m.text"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ def test_deny_invalid_room(self):
+ """Test that we deny relations on non-existant events"""
+ # Create another room and send a message in it.
+ room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+ parent_id = res["event_id"]
+
+ # Attempt to send an annotation to that event.
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
def test_deny_double_react(self):
"""Test that we deny relations on membership events"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
@@ -99,6 +145,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(400, channel.code, channel.json_body)
+ def test_deny_forked_thread(self):
+ """It is invalid to start a thread off a thread."""
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo"},
+ parent_id=self.parent_id,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ parent_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo"},
+ parent_id=parent_id,
+ )
+ self.assertEquals(400, channel.code, channel.json_body)
+
def test_basic_paginate_relations(self):
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -703,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
+ def test_unknown_relations(self):
+ """Unknown relations should be accepted."""
+ channel = self._send_relation("m.relation.test", "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the full
+ # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEquals(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # When bundling the unknown relation is not included.
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+ # But unknown relations can be directly queried.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
+
def _send_relation(
self,
relation_type: str,
@@ -749,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token = self.login(localpart, "abc123")
return user_id, access_token
+
+ def test_background_update(self):
+ """Test the event_arbitrary_relations background update."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_good = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_bad = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_event_id = channel.json_body["event_id"]
+
+ # Clean-up the table as if the inserts did not happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="event_relations",
+ column="event_id",
+ iterable=(annotation_event_id_bad, thread_event_id),
+ keyvalues={},
+ desc="RelationsTestCase.test_background_update",
+ )
+ )
+
+ # Only the "good" annotation should be found.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good],
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # The "good" annotation and the thread should be found, but not the "bad"
+ # annotation.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertCountEqual(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good, thread_event_id],
+ )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 376853fd65..10a4a4dc5e 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -25,7 +25,12 @@ from urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
@@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase):
return event_id
+class RelationsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["experimental_features"] = {"msc3440_enabled": True}
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ self.second_user_id = self.register_user("second", "test")
+ self.second_tok = self.login("second", "test")
+ self.helper.join(
+ room=self.room_id, user=self.second_user_id, tok=self.second_tok
+ )
+
+ self.third_user_id = self.register_user("third", "test")
+ self.third_tok = self.login("third", "test")
+ self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+
+ # An initial event with a relation from second user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 1"},
+ tok=self.tok,
+ )
+ self.event_id_1 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": self.event_id_1,
+ "key": "π",
+ }
+ },
+ tok=self.second_tok,
+ )
+
+ # Another event with a relation from third user.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "Message 2"},
+ tok=self.tok,
+ )
+ self.event_id_2 = res["event_id"]
+ self.helper.send_event(
+ room_id=self.room_id,
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.REFERENCE,
+ "event_id": self.event_id_2,
+ }
+ },
+ tok=self.third_tok,
+ )
+
+ # An event with no relations.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "No relations"},
+ tok=self.tok,
+ )
+
+ def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+ """Make a request to /messages with a filter, returns the chunk of events."""
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ return channel.json_body["chunk"]
+
+ def test_filter_relation_senders(self):
+ # Messages which second user reacted to.
+ filter = {"io.element.relation_senders": [self.second_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+ # Messages which third user reacted to.
+ filter = {"io.element.relation_senders": [self.third_user_id]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+ # Messages which either user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id, self.third_user_id]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_type(self):
+ # Messages which have annotations.
+ filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+ # Messages which have references.
+ filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_2)
+
+ # Messages which have either annotations or references.
+ filter = {
+ "io.element.relation_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 2, chunk)
+ self.assertCountEqual(
+ [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
+ )
+
+ def test_filter_relation_senders_and_type(self):
+ # Messages which second user reacted to.
+ filter = {
+ "io.element.relation_senders": [self.second_user_id],
+ "io.element.relation_types": [RelationTypes.ANNOTATION],
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(len(chunk), 1, chunk)
+ self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+
+
class ContextTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..1af5e5cee5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,10 +19,21 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+ Any,
+ AnyStr,
+ Dict,
+ Iterable,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Tuple,
+ overload,
+)
from unittest.mock import patch
import attr
+from typing_extensions import Literal
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -45,6 +56,32 @@ class RestHelper:
site = attr.ib(type=Site)
auth_user_id = attr.ib()
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: Literal[200] = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> str:
+ ...
+
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: int = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> Optional[str]:
+ ...
+
def create_room_as(
self,
room_creator: Optional[str] = None,
@@ -53,10 +90,8 @@ class RestHelper:
tok: Optional[str] = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
- ) -> str:
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ ) -> Optional[str]:
"""
Create a room.
@@ -99,6 +134,8 @@ class RestHelper:
if expect_code == 200:
return channel.json_body["room_id"]
+ else:
+ return None
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -168,7 +205,7 @@ class RestHelper:
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
- expect_errcode: str = None,
+ expect_errcode: Optional[str] = None,
) -> None:
"""
Send a membership state event into a room.
@@ -227,9 +264,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if body is None:
body = "body_text_here"
@@ -254,9 +289,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -418,7 +451,7 @@ class RestHelper:
path,
content=image_data,
access_token=tok,
- custom_headers=[(b"Content-Length", str(image_length))],
+ custom_headers=[("Content-Length", str(image_length))],
)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
@@ -503,7 +536,7 @@ class RestHelper:
went.
"""
- cookies = {}
+ cookies: Dict[str, str] = {}
# if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id:
@@ -625,7 +658,13 @@ class RestHelper:
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
- location = channel.headers.getRawHeaders("Location")[0]
+ def get_location(channel: FakeChannel) -> str:
+ location_values = channel.headers.getRawHeaders("Location")
+ # Keep mypy happy by asserting that location_values is nonempty
+ assert location_values
+ return location_values[0]
+
+ location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
@@ -639,7 +678,7 @@ class RestHelper:
assert channel.code == 302
channel.extract_cookies(cookies)
- return channel.headers.getRawHeaders("Location")[0]
+ return get_location(channel)
def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
|