diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 03f2112b07..aaa488bced 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -28,7 +28,6 @@ from tests import unittest
class DeviceRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
class DevicesRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 233eba3516..f189b07769 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -78,7 +78,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
"""
Try to get an event report without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, {})
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -473,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
"""
Try to get event report without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
+ channel = self.make_request("GET", self.url, {})
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -599,3 +599,142 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertIn("room_id", content["event_json"])
self.assertIn("sender", content["event_json"])
self.assertIn("content", content["event_json"])
+
+
+class DeleteEventReportTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self._store = hs.get_datastores().main
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ # create report
+ event_id = self.get_success(
+ self._store.add_event_report(
+ "room_id",
+ "event_id",
+ self.other_user,
+ "this makes me sad",
+ {},
+ self.clock.time_msec(),
+ )
+ )
+
+ self.url = f"/_synapse/admin/v1/event_reports/{event_id}"
+
+ def test_no_auth(self) -> None:
+ """
+ Try to delete an event report without authentication.
+ """
+ channel = self.make_request("DELETE", self.url)
+
+ self.assertEqual(401, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self) -> None:
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.other_user_tok,
+ )
+
+ self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_success(self) -> None:
+ """
+ Testing deletion of a report.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual({}, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ # check that report was deleted
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_invalid_report_id(self) -> None:
+ """
+ Testing that an invalid `report_id` returns a 400.
+ """
+
+ # `report_id` is negative
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/-123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is a non-numerical string
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/abcdef",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ # `report_id` is undefined
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual(
+ "The report_id parameter must be a string representing a positive integer.",
+ channel.json_body["error"],
+ )
+
+ def test_report_id_not_found(self) -> None:
+ """
+ Testing that a non-existent `report_id` returns a 404.
+ """
+
+ channel = self.make_request(
+ "DELETE",
+ "/_synapse/admin/v1/event_reports/123",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+ self.assertEqual("Event report not found", channel.json_body["error"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..6d04911d67 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
+from synapse.media.filepath import MediaFilePaths
from synapse.rest.client import login, profile, room
-from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.util import Clock
@@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -213,7 +211,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.admin_user_tok = self.login("admin", "pass")
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
- self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+ self.url = "/_synapse/admin/v1/media/delete"
+ self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
# Move clock up to somewhat realistic time
self.reactor.advance(1000000000)
@@ -332,11 +331,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- def test_delete_media_never_accessed(self) -> None:
+ @parameterized.expand([(True,), (False,)])
+ def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None:
"""
Tests that media deleted if it is older than `before_ts` and never accessed
`last_access_ts` is `NULL` and `created_ts` < `before_ts`
"""
+ url = self.legacy_url if use_legacy_url else self.url
# upload and do not access
server_and_media_id = self._create_media()
@@ -351,7 +352,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
now_ms = self.clock.time_msec()
channel = self.make_request(
"POST",
- self.url + "?before_ts=" + str(now_ms),
+ url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -591,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -721,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
@@ -818,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 453a6e979c..9dbb778679 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
room.register_servlets,
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..28b999573e 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -28,7 +28,6 @@ from tests.unittest import override_config
class ServerNoticeTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -558,7 +557,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user.
Args
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..4b8f889a71 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,8 +28,8 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
+from synapse.media.filepath import MediaFilePaths
from synapse.rest.client import devices, login, logout, profile, register, room, sync
-from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage_controllers = self.hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage_controllers.persistence.persist_event(event, context))
+ context = self.get_success(unpersisted_context.persist(event))
+
+ self.get_success(persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- async def check_username(username: str) -> bool:
- if username == "allowed":
- return True
+ async def check_username(
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
+ inhibit_user_in_use_error: bool = False,
+ ) -> None:
+ if localpart == "allowed":
+ return
raise SynapseError(
400,
"User ID already taken.",
@@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
)
handler = self.hs.get_registration_handler()
- handler.check_username = check_username
+ handler.check_username = check_username # type: ignore[assignment]
def test_username_available(self) -> None:
"""
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..2b05dffc7d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
-
servlets = [
account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
class DeactivateTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
class WhoamiTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
-
servlets = [
account.register_servlets,
login.register_servlets,
@@ -1193,7 +1189,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
return {}
# Register a mock that will return the expected result depending on the remote.
- self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+ self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment]
# Check that we've got the correct response from the client-side endpoint.
self._test_status(
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..0d8fe77b88 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
@@ -43,13 +43,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
super().__init__(hs)
self.recaptcha_attempts: List[Tuple[dict, str]] = []
+ def is_enabled(self) -> bool:
+ return True
+
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
class FallbackAuthTests(unittest.HomeserverTestCase):
-
servlets = [
auth.register_servlets,
register.register_servlets,
@@ -57,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_registration_captcha"] = True
@@ -1319,16 +1320,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
- # Now try to exchange the login token
- channel = make_request(
- self.hs.get_reactor(),
- self.site,
- "POST",
- "/login",
- content={"type": "m.login.token", "token": login_token},
- )
- # It should have failed
- self.assertEqual(channel.code, 403)
+ # Now try to exchange the login token; it should fail.
+ self.helper.login_via_token(login_token, 403)
@override_config(
{
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index d1751e1557..c16e8d43f4 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -26,7 +26,6 @@ from tests.unittest import override_config
class CapabilitiesTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
capabilities.register_servlets,
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index b1ca81a911..bb845179d3 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
hijack_auth = False
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["form_secret"] = "123abc"
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 7a88aa2cda..6490e883bf 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -28,7 +28,6 @@ from tests.unittest import override_config
class DirectoryTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 9fa1f82dfe..f31ebc8021 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -26,7 +26,6 @@ from tests import unittest
class EphemeralMessageTestCase(unittest.HomeserverTestCase):
-
user_id = "@user:test"
servlets = [
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index a9b7db9db2..54df2a252c 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_registration_captcha"] = False
config["enable_registration"] = True
@@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
@@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..91678abf13 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.HomeserverTestCase):
-
user_id = "@apple:test"
hijack_auth = True
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
@@ -63,14 +62,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
- self.hs.is_mine = lambda target_user: False
+ self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.hs.is_mine = _is_mine
+ self.hs.is_mine = _is_mine # type: ignore[assignment]
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 741fecea77..8ee5489057 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -14,12 +14,21 @@
from http import HTTPStatus
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
+from signedjson.sign import sign_json
+
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import keys, login
+from synapse.types import JsonDict
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
+from tests.unittest import override_config
class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -118,3 +127,135 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertIn(bob, channel.json_body["device_keys"])
+
+ def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+ # We only generate a master key to simplify the test.
+ master_signing_key = generate_signing_key(device_id)
+ master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+ return {
+ "master_key": sign_json(
+ {
+ "user_id": user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_verify_key: master_verify_key},
+ },
+ user_id,
+ master_signing_key,
+ ),
+ }
+
+ def test_device_signing_with_uia(self) -> None:
+ """Device signing key upload requires UIA."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # add UI auth
+ content["auth"] = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": alice_id},
+ "password": password,
+ "session": session,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config({"ui_auth": {"session_timeout": "15m"}})
+ def test_device_signing_with_uia_session_timeout(self) -> None:
+ """Device signing key upload requires UIA but passes within the grace period."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ content = self.make_device_keys(alice_id, device_id)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ content,
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ @override_config(
+ {
+ "experimental_features": {"msc3967_enabled": True},
+ "ui_auth": {"session_timeout": "15s"},
+ }
+ )
+ def test_device_signing_with_msc3967(self) -> None:
+ """Device signing key upload follows MSC3967 behaviour when enabled."""
+ password = "wonderland"
+ device_id = "ABCDEFGHI"
+ alice_id = self.register_user("alice", password)
+ alice_token = self.login("alice", password, device_id=device_id)
+
+ keys1 = self.make_device_keys(alice_id, device_id)
+
+ # Initial request should succeed as no existing keys are present.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys1,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ keys2 = self.make_device_keys(alice_id, device_id)
+
+ # Subsequent request should require UIA as keys already exist even though session_timeout is set.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys2,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+
+ # Grab the session
+ session = channel.json_body["session"]
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ # add UI auth
+ keys2["auth"] = {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": alice_id},
+ "password": password,
+ "session": session,
+ }
+
+ # Request should complete
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys2,
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ff5baa9f0a..62acf4f44e 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [
class LoginRestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
class CASTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
]
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index 6aedc1a11c..b8187db982 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
admin.register_servlets,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..dcbb125a3b 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -35,15 +35,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
servlets = [presence.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
- presence_handler = Mock(spec=PresenceHandler)
- presence_handler.set_state.return_value = make_awaitable(None)
+ self.presence_handler = Mock(spec=PresenceHandler)
+ self.presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
federation_client=Mock(),
- presence_handler=presence_handler,
+ presence_handler=self.presence_handler,
)
return hs
@@ -61,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
+ self.assertEqual(self.presence_handler.set_state.call_count, 1)
@unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self) -> None:
@@ -76,4 +75,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
+ self.assertEqual(self.presence_handler.set_state.call_count, 0)
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 8de5a342ae..27c93ad761 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -30,7 +30,6 @@ from tests import unittest
class ProfileTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
@@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..b228dba861 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
login.register_servlets,
register.register_servlets,
@@ -151,7 +150,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self) -> None:
- self.hs.config.key.macaroon_secret_key = "test"
+ self.hs.config.key.macaroon_secret_key = b"test"
self.hs.config.registration.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
class AccountValidityTestCase(unittest.HomeserverTestCase):
-
servlets = [
register.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
-
servlets = [
register.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
-
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -1166,12 +1162,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+ self.hs.config.account_validity.account_validity_startup_job_max_delta = (
+ self.max_delta
+ )
now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ assert res is not None
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8a6911d5e..fbbbcb23f1 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -30,7 +30,6 @@ from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
-from tests.unittest import override_config
class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -403,7 +402,7 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_edit(self) -> None:
"""Test that a simple edit works."""
-
+ orig_body = {"body": "Hi!", "msgtype": "m.text"}
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
edit_event_content = {
"msgtype": "m.text",
@@ -424,9 +423,7 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
+ self.assertEqual(channel.json_body["content"], orig_body)
self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
# Request the room messages.
@@ -443,7 +440,7 @@ class RelationsTestCase(BaseRelationsTestCase):
)
# Request the room context.
- # /context should return the edited event.
+ # /context should return the original event.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
@@ -453,7 +450,7 @@ class RelationsTestCase(BaseRelationsTestCase):
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
# Request sync, but limit the timeline so it becomes limited (and includes
# bundled aggregations).
@@ -491,45 +488,11 @@ class RelationsTestCase(BaseRelationsTestCase):
edit_event_content,
)
- @override_config({"experimental_features": {"msc3925_inhibit_edit": True}})
- def test_edit_inhibit_replace(self) -> None:
- """
- If msc3925_inhibit_edit is enabled, then the original event should not be
- replaced.
- """
-
- new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- edit_event_content = {
- "msgtype": "m.text",
- "body": "foo",
- "m.new_content": new_body,
- }
- channel = self._send_relation(
- RelationTypes.REPLACE,
- "m.room.message",
- content=edit_event_content,
- )
- edit_event_id = channel.json_body["event_id"]
-
- # /context should return the *original* event.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
- self._assert_edit_bundle(
- channel.json_body["event"], edit_event_id, edit_event_content
- )
-
def test_multi_edit(self) -> None:
"""Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled.
"""
-
+ orig_body = {"body": "Hi!", "msgtype": "m.text"}
self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
@@ -570,7 +533,7 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
@@ -642,6 +605,7 @@ class RelationsTestCase(BaseRelationsTestCase):
def test_edit_edit(self) -> None:
"""Test that an edit cannot be edited."""
+ orig_body = {"body": "Hi!", "msgtype": "m.text"}
new_body = {"msgtype": "m.text", "body": "Initial edit"}
edit_event_content = {
"msgtype": "m.text",
@@ -675,14 +639,12 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
- )
+ self.assertEqual(channel.json_body["content"], orig_body)
# The relations information should not include the edit to the edit.
self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
- # /context should return the event updated for the *first* edit
+ # /context should return the bundled edit for the *first* edit
# (The edit to the edit should be ignored.)
channel = self.make_request(
"GET",
@@ -690,7 +652,7 @@ class RelationsTestCase(BaseRelationsTestCase):
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(channel.json_body["event"]["content"], new_body)
+ self.assertEqual(channel.json_body["event"]["content"], orig_body)
self._assert_edit_bundle(
channel.json_body["event"], edit_event_id, edit_event_content
)
@@ -1080,48 +1042,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
]
assert_bundle(self._find_event_in_chunk(chunk))
- def test_annotation(self) -> None:
- """
- Test that annotations get correctly bundled.
- """
- # Setup by sending a variety of relations.
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
- def assert_annotations(bundled_aggregations: JsonDict) -> None:
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- bundled_aggregations,
- )
-
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
-
- def test_annotation_to_annotation(self) -> None:
- """Any relation to an annotation should be ignored."""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- event_id = channel.json_body["event_id"]
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id
- )
-
- # Fetch the initial annotation event to see if it has bundled aggregations.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}",
- access_token=self.user_token,
- )
- self.assertEquals(200, channel.code, channel.json_body)
- # The first annotationt should not have any bundled aggregations.
- self.assertNotIn("m.relations", channel.json_body["unsigned"])
-
def test_reference(self) -> None:
"""
Test that references get correctly bundled.
@@ -1138,7 +1058,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
def test_thread(self) -> None:
"""
@@ -1183,7 +1103,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
@@ -1208,9 +1128,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_2 = channel.json_body["event_id"]
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2
)
+ reference_event_id = channel.json_body["event_id"]
def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
@@ -1235,17 +1156,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assert_dict(
{
"m.relations": {
- RelationTypes.ANNOTATION: {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 1},
- ]
+ RelationTypes.REFERENCE: {
+ "chunk": [{"event_id": reference_event_id}]
},
}
},
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6)
def test_nested_thread(self) -> None:
"""
@@ -1330,7 +1249,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
thread_summary = relations_dict[RelationTypes.THREAD]
self.assertIn("latest_event", thread_summary)
latest_event_in_thread = thread_summary["latest_event"]
- self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
# The latest event in the thread should have the edit appear under the
# bundled aggregations.
self.assertDictContainsSubset(
@@ -1363,10 +1281,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_id = channel.json_body["event_id"]
- # Annotate the thread.
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ # Make a reference to the thread.
+ channel = self._send_relation(
+ RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id
)
+ reference_event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
@@ -1377,9 +1296,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(
channel.json_body["unsigned"].get("m.relations"),
{
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
},
)
@@ -1396,9 +1313,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(
thread_message["unsigned"].get("m.relations"),
{
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
},
)
@@ -1410,7 +1325,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
Note that the spec allows for a server to return additional fields beyond
what is specified.
"""
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
+ reference_event_id = channel.json_body["event_id"]
# Note that the sync filter does not include "unsigned" as a field.
filter = urllib.parse.quote_plus(
@@ -1428,7 +1344,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# Ensure there's bundled aggregations on it.
self.assertIn("unsigned", parent_event)
- self.assertIn("m.relations", parent_event["unsigned"])
+ self.assertEqual(
+ parent_event["unsigned"].get("m.relations"),
+ {
+ RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
+ },
+ )
class RelationIgnoredUserTestCase(BaseRelationsTestCase):
@@ -1475,53 +1396,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
return before_aggregations[relation_type], after_aggregations[relation_type]
- def test_annotation(self) -> None:
- """Annotations should ignore"""
- # Send 2 from us, 2 from the to be ignored user.
- allowed_event_ids = []
- ignored_event_ids = []
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
- allowed_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
- allowed_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="a",
- access_token=self.user2_token,
- )
- ignored_event_ids.append(channel.json_body["event_id"])
- channel = self._send_relation(
- RelationTypes.ANNOTATION,
- "m.reaction",
- key="c",
- access_token=self.user2_token,
- )
- ignored_event_ids.append(channel.json_body["event_id"])
-
- before_aggregations, after_aggregations = self._test_ignored_user(
- RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids
- )
-
- self.assertCountEqual(
- before_aggregations["chunk"],
- [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- {"type": "m.reaction", "key": "c", "count": 1},
- ],
- )
-
- self.assertCountEqual(
- after_aggregations["chunk"],
- [
- {"type": "m.reaction", "key": "a", "count": 1},
- {"type": "m.reaction", "key": "b", "count": 1},
- ],
- )
-
def test_reference(self) -> None:
- """Annotations should ignore"""
+ """Aggregations should exclude reference relations from ignored users"""
channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
allowed_event_ids = [channel.json_body["event_id"]]
@@ -1544,7 +1420,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
)
def test_thread(self) -> None:
- """Annotations should ignore"""
+ """Aggregations should exclude thread relations from ignored users"""
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
allowed_event_ids = [channel.json_body["event_id"]]
@@ -1618,43 +1494,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
for t in threads
]
- def test_redact_relation_annotation(self) -> None:
- """
- Test that annotations of an event are properly handled after the
- annotation is redacted.
-
- The redacted relation should not be included in bundled aggregations or
- the response to relations.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- to_redact_event_id = channel.json_body["event_id"]
-
- channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- unredacted_event_id = channel.json_body["event_id"]
-
- # Both relations should exist.
- event_ids = self._get_related_events()
- relations = self._get_bundled_aggregations()
- self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
- self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
- )
-
- # Redact one of the reactions.
- self._redact(to_redact_event_id)
-
- # The unredacted relation should still exist.
- event_ids = self._get_related_events()
- relations = self._get_bundled_aggregations()
- self.assertEquals(event_ids, [unredacted_event_id])
- self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
- )
-
def test_redact_relation_thread(self) -> None:
"""
Test that thread replies are properly handled after the thread reply redacted.
@@ -1775,14 +1614,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
is redacted.
"""
# Add a relation
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+ channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
related_event_id = channel.json_body["event_id"]
# The relations should exist.
event_ids = self._get_related_events()
relations = self._get_bundled_aggregations()
self.assertEqual(len(event_ids), 1)
- self.assertIn(RelationTypes.ANNOTATION, relations)
+ self.assertIn(RelationTypes.REFERENCE, relations)
# Redact the original event.
self._redact(self.parent_id)
@@ -1792,8 +1631,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [related_event_id])
self.assertEquals(
- relations["m.annotation"],
- {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
+ relations[RelationTypes.REFERENCE],
+ {"chunk": [{"event_id": related_event_id}]},
)
def test_redact_parent_thread(self) -> None:
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index c0eb5d01a6..8dbd64be55 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
class RendezvousServletTestCase(unittest.HomeserverTestCase):
-
servlets = [
rendezvous.register_servlets,
]
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
data = {"reason": None, "score": None}
self._assert_status(400, data)
+ def test_cannot_report_nonexistent_event(self) -> None:
+ """
+ Tests that we don't accept event reports for events which do not exist.
+ """
+ channel = self.make_request(
+ "POST",
+ f"rooms/{self.room_id}/report/$nonsenseeventid:test",
+ {"reason": "i am very sad"},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(404, channel.code, msg=channel.result["body"])
+
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")
+ assert isinstance(first_event_id, str)
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
+ assert isinstance(valid_event_id, str)
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
@@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Check that we can still access state events that were sent before the event that
# has been purged.
- self.get_event(room_id, create_event.event_id)
+ self.get_event(room_id, bool(create_event))
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
event = self.get_success(self.store.get_event(event_id, allow_none=True))
@@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.assertIsNone(event)
return {}
- self.assertIsNotNone(event)
+ assert event is not None
time_now = self.clock.time_msec()
serialized = self.serializer.serialize_event(event, time_now)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..a4900703c4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase):
servlets = [room.register_servlets, room.register_deprecated_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
@@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase):
rmcreator_id = "@notme:red"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
self.uncreated_rmid = "!aa:test"
@@ -715,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(33, channel.resource_usage.db_txn_count)
+ self.assertEqual(30, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -728,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(36, channel.resource_usage.db_txn_count)
+ self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase):
class RoomJoinTestCase(RoomBase):
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
hijack_auth = False
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
# Register the user who does the searching
self.user_id2 = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
@@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
self.url = b"/_matrix/client/r0/publicRooms"
config = self.default_config()
@@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["allow_public_rooms_without_auth"] = True
self.hs = self.setup_test_homeserver(config=config)
@@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase):
class ContextTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
class ThreepidInviteTestCase(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -3382,8 +3371,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3438,13 +3427,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
"""
Test allowing/blocking threepid invites with a spam-check module.
- In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+ """
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3563,8 +3553,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
)
event.internal_metadata.outlier = True
+ persistence = self._storage_controllers.persistence
+ assert persistence is not None
self.get_success(
- self._storage_controllers.persistence.persist_event(
+ persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_invite_3pid(self) -> None:
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_identity_handler()
- identity_handler.lookup_3pid = Mock(
+ identity_handler.lookup_3pid = Mock( # type: ignore[assignment]
side_effect=AssertionError("This should not get called")
)
@@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
event_source.get_new_events(
user=UserID.from_string(self.other_user_id),
from_key=0,
- limit=None,
+ limit=10,
room_ids=[room_id],
is_guest=False,
)
@@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
@@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b9047194dd..9c876c7a32 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -41,7 +41,6 @@ from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
-
user_id = "@apple:test"
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
class SyncTypingTests(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
class ExcludeRoomTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..3b99513707 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
+
# patch the rules module with a Mock which will return False for some event
# types
async def check(
@@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -275,6 +277,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
+
# first patch the event checker so that it will modify the event
async def check(
ev: EventBase, state: StateMap[EventBase]
@@ -425,7 +428,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
- if event.is_state and event.type == EventTypes.PowerLevels:
+ if event.is_state() and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
"room_id": event.room_id,
@@ -931,3 +934,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right parameters
self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_add_and_remove_user_third_party_identifier(self) -> None:
+ """Tests that the on_add_user_third_party_identifier and
+ on_remove_user_third_party_identifier module callbacks are called
+ just before a 3PID is associated with or removed from an account.
+ """
+ # Pretend to be a Synapse module and register both callbacks as mocks.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_add_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_add_user_third_party_identifier_callback_mock
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked add callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_add_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_add_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ # Now remove the 3PID from the user
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [],
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+ def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
+ self,
+ ) -> None:
+ """Tests that the on_remove_user_third_party_identifier module callback is called
+ when a user is deactivated and their third-party ID associations are deleted.
+ """
+ # Pretend to be a Synapse module and register the remove callback as a mock.
+ third_party_rules = self.hs.get_third_party_event_rules()
+ on_remove_user_third_party_identifier_callback_mock = Mock(
+ return_value=make_awaitable(None)
+ )
+ third_party_rules._on_threepid_bind_callbacks.append(
+ on_remove_user_third_party_identifier_callback_mock
+ )
+
+ # Register an admin user.
+ self.register_user("admin", "password", admin=True)
+ admin_tok = self.login("admin", "password")
+
+ # Also register a normal user we can modify.
+ user_id = self.register_user("user", "password")
+
+ # Add a 3PID to the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "threepids": [
+ {
+ "medium": "email",
+ "address": "foo@example.com",
+ },
+ ],
+ },
+ access_token=admin_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Now deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % user_id,
+ {
+ "deactivated": True,
+ },
+ access_token=admin_tok,
+ )
+
+ # Check that the mocked remove callback was called with the appropriate
+ # 3PID details.
+ self.assertEqual(channel.code, 200, channel.json_body)
+ on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+ args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+ self.assertEqual(args, (user_id, "email", "foo@example.com"))
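The two new tests above only assert that mocked callbacks fire with the arguments (user_id, medium, address). For orientation, a module consuming these callbacks might look roughly like the sketch below. This is not part of the patch: only ModuleApi.register_third_party_rules_callbacks is known to exist, and the keyword names used in the registration call are assumptions inferred from the callback names in the tests.

    class ExampleThreepidAuditModule:
        """Minimal sketch of a module reacting to 3PID additions and removals.

        The (user_id, medium, address) signature is taken from the assertions
        in the tests above; the registration keyword names are assumptions,
        not confirmed by this diff.
        """

        def __init__(self, config: dict, api) -> None:
            # register_third_party_rules_callbacks is the existing ModuleApi
            # hook for third-party-rules callbacks; the keyword names below
            # are assumed to match the new callbacks.
            api.register_third_party_rules_callbacks(
                on_add_user_third_party_identifier=self.on_add,
                on_remove_user_third_party_identifier=self.on_remove,
            )

        async def on_add(self, user_id: str, medium: str, address: str) -> None:
            # For example, mirror the new association to an external audit log.
            print(f"added {medium}:{address} for {user_id}")

        async def on_remove(self, user_id: str, medium: str, address: str) -> None:
            print(f"removed {medium}:{address} for {user_id}")
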
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3086e1b565..d8dc56261a 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
- self.mock_key = "foo"
+
+ # Here we make sure that we're setting all the fields that HttpTransactionCache
+ # uses to build the transaction key.
+ self.mock_request = Mock()
+ self.mock_request.path = b"/foo/bar"
+ self.mock_requester = Mock()
+ self.mock_requester.app_service = None
+ self.mock_requester.is_guest = False
+ self.mock_requester.access_token_id = 1234
@defer.inlineCallbacks
def test_executes_given_function(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
- res = yield self.cache.fetch_or_execute(
- self.mock_key, cb, "some_arg", keyword="arg"
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
)
cb.assert_called_once_with("some_arg", keyword="arg")
self.assertEqual(res, self.mock_http_response)
@@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
- res = yield self.cache.fetch_or_execute(
- self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request,
+ self.mock_requester,
+ cb,
+ "some_arg",
+ keyword="arg",
+ changing_args=i,
)
self.assertEqual(res, self.mock_http_response)
# expect only a single call to do the work
@@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1:
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertIs(current_context(), c1)
self.assertEqual(res, (1, {}))
@@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
- yield self.cache.fetch_or_execute(self.mock_key, cb)
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
- yield self.cache.fetch_or_execute(self.mock_key, cb)
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
- res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+ res = yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb
+ )
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# still using cache
cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
- yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+ yield self.cache.fetch_or_execute_request(
+ self.mock_request, self.mock_requester, cb, "an arg"
+ )
# no longer using cache
self.assertEqual(cb.call_count, 2)
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
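The hunks above replace the bare string transaction key with a (request, requester) pair. A minimal, standalone sketch of building such a pair with unittest.mock, using the same field names as the updated setUp (path on the request; app_service, is_guest and access_token_id on the requester, which per the comment above are the fields the cache uses to build its key); the path value is an arbitrary example:

    from unittest.mock import Mock

    # Request mock: the transaction key incorporates the request path.
    mock_request = Mock()
    mock_request.path = b"/_matrix/client/v3/rooms/!room:test/send/m.room.message/txn1"

    # Requester mock: set the key-relevant fields explicitly so they hold real
    # values rather than auto-created Mock attributes.
    mock_requester = Mock()
    mock_requester.app_service = None
    mock_requester.is_guest = False
    mock_requester.access_token_id = 1234
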
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.room_id, EventTypes.Tombstone, ""
)
)
- self.assertIsNotNone(tombstone_event)
+ assert tombstone_event is not None
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
# Check that the new room exists.
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr
from typing_extensions import Literal
+from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -67,6 +68,7 @@ class RestHelper:
"""
hs: HomeServer
+ reactor: MemoryReactorClock
site: Site
auth_user_id: Optional[str]
@@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {})
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}"
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
path,
@@ -488,7 +490,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+ channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
- self.hs.get_reactor(),
- FakeSite(resource, self.hs.get_reactor()),
+ self.reactor,
+ FakeSite(resource, self.reactor),
"POST",
path,
content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
"account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
- Returns the result of the final token login.
+ Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body
login_token = m.group(1)
- # finally, submit the matrix login token to the login API, which gives us our
- # matrix access token and device id.
+ return self.login_via_token(login_token, expected_status), grant
+
+ def login_via_token(
+ self,
+ login_token: str,
+ expected_status: int = 200,
+ ) -> JsonDict:
+ """Submit the matrix login token to the login API, which gives us our
+ matrix access token and device id.
+
+ Returns the result of the token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
"/login",
@@ -684,7 +704,7 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
- return channel.json_body, grant
+ return channel.json_body
def auth_via_oidc(
self,
@@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
uri,
@@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id})
)
# hit the redirect url (which will issue a cookie and state)
- channel = make_request(
- self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
- )
+ channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
# that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 23f227aed6..b59d9dfd4d 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -31,7 +31,6 @@ from tests.utils import MockClock
class MediaRetentionTestCase(unittest.HomeserverTestCase):
-
ONE_DAY_IN_MS = 24 * 60 * 60 * 1000
THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 2c321f8d04..e91dc581c2 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -26,8 +26,8 @@ from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from synapse.config.oembed import OEmbedEndpointConfig
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
-from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.rest.media.media_repository_resource import MediaRepositoryResource
+from synapse.rest.media.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["url_preview_enabled"] = True
config["max_spider_size"] = 9999999
@@ -83,7 +82,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
config["media_store_path"] = self.media_store_path
provider_config = {
- "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
@@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"]
@@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics: str = "TCP",
) -> IResolutionReceiver:
-
resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution)
if hostName not in self.lookups:
@@ -660,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"""If the preview image doesn't exist, ensure some data is returned."""
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
- end_content = (
+ result = (
b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
)
@@ -681,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
- % (len(end_content),)
- + end_content
+ % (len(result),)
+ + result
)
self.pump()
@@ -691,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# The image should not be in the result.
self.assertNotIn("og:image", channel.json_body)
+ def test_oembed_failure(self) -> None:
+ """If the autodiscovered oEmbed URL fails, ensure some data is returned."""
+ self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+ result = b"""
+ <title>oEmbed Autodiscovery Fail</title>
+ <link rel="alternate" type="application/json+oembed"
+ href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json"
+ title="matrixdotorg" />
+ """
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ )
+ % (len(result),)
+ + result
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # Even though the oEmbed request failed, the HTML <title> is still used.
+ self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail")
+
def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py
deleted file mode 100644
index b1ee10cfcc..0000000000
--- a/tests/rest/media/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
deleted file mode 100644
index c73179151a..0000000000
--- a/tests/rest/media/v1/test_base.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest.media.v1._base import get_filename_from_headers
-
-from tests import unittest
-
-
-class GetFileNameFromHeadersTests(unittest.TestCase):
- # input -> expected result
- TEST_CASES = {
- b"inline; filename=abc.txt": "abc.txt",
- b'inline; filename="azerty"': "azerty",
- b'inline; filename="aze%20rty"': "aze%20rty",
- b'inline; filename="aze"rty"': 'aze"rty',
- b'inline; filename="azer;ty"': "azer;ty",
- b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
- }
-
- def tests(self) -> None:
- for hdr, expected in self.TEST_CASES.items():
- res = get_filename_from_headers({b"Content-Disposition": [hdr]})
- self.assertEqual(
- res,
- expected,
- f"expected output for {hdr!r} to be {expected} but was {res}",
- )
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
deleted file mode 100644
index 43e6f0f70a..0000000000
--- a/tests/rest/media/v1/test_filepath.py
+++ /dev/null
@@ -1,595 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import inspect
-import os
-from typing import Iterable
-
-from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check
-
-from tests import unittest
-
-
-class MediaFilePathsTestCase(unittest.TestCase):
- def setUp(self) -> None:
- super().setUp()
-
- self.filepaths = MediaFilePaths("/media_store")
-
- def test_local_media_filepath(self) -> None:
- """Test local media paths"""
- self.assertEqual(
- self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
- "local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
- self.assertEqual(
- self.filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"),
- "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
-
- def test_local_media_thumbnail(self) -> None:
- """Test local media thumbnail paths"""
- self.assertEqual(
- self.filepaths.local_media_thumbnail_rel(
- "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
- ),
- "local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
- )
- self.assertEqual(
- self.filepaths.local_media_thumbnail(
- "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
- ),
- "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
- )
-
- def test_local_media_thumbnail_dir(self) -> None:
- """Test local media thumbnail directory paths"""
- self.assertEqual(
- self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
- "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
-
- def test_remote_media_filepath(self) -> None:
- """Test remote media paths"""
- self.assertEqual(
- self.filepaths.remote_media_filepath_rel(
- "example.com", "GerZNDnDZVjsOtardLuwfIBg"
- ),
- "remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
- self.assertEqual(
- self.filepaths.remote_media_filepath(
- "example.com", "GerZNDnDZVjsOtardLuwfIBg"
- ),
- "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
-
- def test_remote_media_thumbnail(self) -> None:
- """Test remote media thumbnail paths"""
- self.assertEqual(
- self.filepaths.remote_media_thumbnail_rel(
- "example.com",
- "GerZNDnDZVjsOtardLuwfIBg",
- 800,
- 600,
- "image/jpeg",
- "scale",
- ),
- "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
- )
- self.assertEqual(
- self.filepaths.remote_media_thumbnail(
- "example.com",
- "GerZNDnDZVjsOtardLuwfIBg",
- 800,
- 600,
- "image/jpeg",
- "scale",
- ),
- "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
- )
-
- def test_remote_media_thumbnail_legacy(self) -> None:
- """Test old-style remote media thumbnail paths"""
- self.assertEqual(
- self.filepaths.remote_media_thumbnail_rel_legacy(
- "example.com", "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg"
- ),
- "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
- )
-
- def test_remote_media_thumbnail_dir(self) -> None:
- """Test remote media thumbnail directory paths"""
- self.assertEqual(
- self.filepaths.remote_media_thumbnail_dir(
- "example.com", "GerZNDnDZVjsOtardLuwfIBg"
- ),
- "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
-
- def test_url_cache_filepath(self) -> None:
- """Test URL cache paths"""
- self.assertEqual(
- self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
- "url_cache/2020-01-02/GerZNDnDZVjsOtar",
- )
- self.assertEqual(
- self.filepaths.url_cache_filepath("2020-01-02_GerZNDnDZVjsOtar"),
- "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
- )
-
- def test_url_cache_filepath_legacy(self) -> None:
- """Test old-style URL cache paths"""
- self.assertEqual(
- self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
- "url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
- self.assertEqual(
- self.filepaths.url_cache_filepath("GerZNDnDZVjsOtardLuwfIBg"),
- "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
-
- def test_url_cache_filepath_dirs_to_delete(self) -> None:
- """Test URL cache cleanup paths"""
- self.assertEqual(
- self.filepaths.url_cache_filepath_dirs_to_delete(
- "2020-01-02_GerZNDnDZVjsOtar"
- ),
- ["/media_store/url_cache/2020-01-02"],
- )
-
- def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
- """Test old-style URL cache cleanup paths"""
- self.assertEqual(
- self.filepaths.url_cache_filepath_dirs_to_delete(
- "GerZNDnDZVjsOtardLuwfIBg"
- ),
- [
- "/media_store/url_cache/Ge/rZ",
- "/media_store/url_cache/Ge",
- ],
- )
-
- def test_url_cache_thumbnail(self) -> None:
- """Test URL cache thumbnail paths"""
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_rel(
- "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
- ),
- "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
- )
- self.assertEqual(
- self.filepaths.url_cache_thumbnail(
- "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
- ),
- "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
- )
-
- def test_url_cache_thumbnail_legacy(self) -> None:
- """Test old-style URL cache thumbnail paths"""
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_rel(
- "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
- ),
- "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
- )
- self.assertEqual(
- self.filepaths.url_cache_thumbnail(
- "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
- ),
- "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
- )
-
- def test_url_cache_thumbnail_directory(self) -> None:
- """Test URL cache thumbnail directory paths"""
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_directory_rel(
- "2020-01-02_GerZNDnDZVjsOtar"
- ),
- "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
- )
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_directory("2020-01-02_GerZNDnDZVjsOtar"),
- "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
- )
-
- def test_url_cache_thumbnail_directory_legacy(self) -> None:
- """Test old-style URL cache thumbnail directory paths"""
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_directory_rel(
- "GerZNDnDZVjsOtardLuwfIBg"
- ),
- "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_directory("GerZNDnDZVjsOtardLuwfIBg"),
- "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- )
-
- def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
- """Test URL cache thumbnail cleanup paths"""
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_dirs_to_delete(
- "2020-01-02_GerZNDnDZVjsOtar"
- ),
- [
- "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
- "/media_store/url_cache_thumbnails/2020-01-02",
- ],
- )
-
- def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
- """Test old-style URL cache thumbnail cleanup paths"""
- self.assertEqual(
- self.filepaths.url_cache_thumbnail_dirs_to_delete(
- "GerZNDnDZVjsOtardLuwfIBg"
- ),
- [
- "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
- "/media_store/url_cache_thumbnails/Ge/rZ",
- "/media_store/url_cache_thumbnails/Ge",
- ],
- )
-
- def test_server_name_validation(self) -> None:
- """Test validation of server names"""
- self._test_path_validation(
- [
- "remote_media_filepath_rel",
- "remote_media_filepath",
- "remote_media_thumbnail_rel",
- "remote_media_thumbnail",
- "remote_media_thumbnail_rel_legacy",
- "remote_media_thumbnail_dir",
- ],
- parameter="server_name",
- valid_values=[
- "matrix.org",
- "matrix.org:8448",
- "matrix-federation.matrix.org",
- "matrix-federation.matrix.org:8448",
- "10.1.12.123",
- "10.1.12.123:8448",
- "[fd00:abcd::ffff]",
- "[fd00:abcd::ffff]:8448",
- ],
- invalid_values=[
- "/matrix.org",
- "matrix.org/..",
- "matrix.org\x00",
- "",
- ".",
- "..",
- "/",
- ],
- )
-
- def test_file_id_validation(self) -> None:
- """Test validation of local, remote and legacy URL cache file / media IDs"""
- # File / media IDs get split into three parts to form paths, consisting of the
- # first two characters, next two characters and rest of the ID.
- valid_file_ids = [
- "GerZNDnDZVjsOtardLuwfIBg",
- # Unexpected, but produces an acceptable path:
- "GerZN", # "N" becomes the last directory
- ]
- invalid_file_ids = [
- "/erZNDnDZVjsOtardLuwfIBg",
- "Ge/ZNDnDZVjsOtardLuwfIBg",
- "GerZ/DnDZVjsOtardLuwfIBg",
- "GerZ/..",
- "G\x00rZNDnDZVjsOtardLuwfIBg",
- "Ger\x00NDnDZVjsOtardLuwfIBg",
- "GerZNDnDZVjsOtardLuwfIBg\x00",
- "",
- "Ge",
- "GerZ",
- "GerZ.",
- "..rZNDnDZVjsOtardLuwfIBg",
- "Ge..NDnDZVjsOtardLuwfIBg",
- "GerZ..",
- "GerZ/",
- ]
-
- self._test_path_validation(
- [
- "local_media_filepath_rel",
- "local_media_filepath",
- "local_media_thumbnail_rel",
- "local_media_thumbnail",
- "local_media_thumbnail_dir",
- # Legacy URL cache media IDs
- "url_cache_filepath_rel",
- "url_cache_filepath",
- # `url_cache_filepath_dirs_to_delete` is tested below.
- "url_cache_thumbnail_rel",
- "url_cache_thumbnail",
- "url_cache_thumbnail_directory_rel",
- "url_cache_thumbnail_directory",
- "url_cache_thumbnail_dirs_to_delete",
- ],
- parameter="media_id",
- valid_values=valid_file_ids,
- invalid_values=invalid_file_ids,
- )
-
- # `url_cache_filepath_dirs_to_delete` ignores what would be the last path
- # component, so only the first 4 characters matter.
- self._test_path_validation(
- [
- "url_cache_filepath_dirs_to_delete",
- ],
- parameter="media_id",
- valid_values=valid_file_ids,
- invalid_values=[
- "/erZNDnDZVjsOtardLuwfIBg",
- "Ge/ZNDnDZVjsOtardLuwfIBg",
- "G\x00rZNDnDZVjsOtardLuwfIBg",
- "Ger\x00NDnDZVjsOtardLuwfIBg",
- "",
- "Ge",
- "..rZNDnDZVjsOtardLuwfIBg",
- "Ge..NDnDZVjsOtardLuwfIBg",
- ],
- )
-
- self._test_path_validation(
- [
- "remote_media_filepath_rel",
- "remote_media_filepath",
- "remote_media_thumbnail_rel",
- "remote_media_thumbnail",
- "remote_media_thumbnail_rel_legacy",
- "remote_media_thumbnail_dir",
- ],
- parameter="file_id",
- valid_values=valid_file_ids,
- invalid_values=invalid_file_ids,
- )
-
- def test_url_cache_media_id_validation(self) -> None:
- """Test validation of URL cache media IDs"""
- self._test_path_validation(
- [
- "url_cache_filepath_rel",
- "url_cache_filepath",
- # `url_cache_filepath_dirs_to_delete` only cares about the date prefix
- "url_cache_thumbnail_rel",
- "url_cache_thumbnail",
- "url_cache_thumbnail_directory_rel",
- "url_cache_thumbnail_directory",
- "url_cache_thumbnail_dirs_to_delete",
- ],
- parameter="media_id",
- valid_values=[
- "2020-01-02_GerZNDnDZVjsOtar",
- "2020-01-02_G", # Unexpected, but produces an acceptable path
- ],
- invalid_values=[
- "2020-01-02",
- "2020-01-02-",
- "2020-01-02-.",
- "2020-01-02-..",
- "2020-01-02-/",
- "2020-01-02-/GerZNDnDZVjsOtar",
- "2020-01-02-GerZNDnDZVjsOtar/..",
- "2020-01-02-GerZNDnDZVjsOtar\x00",
- ],
- )
-
- def test_content_type_validation(self) -> None:
- """Test validation of thumbnail content types"""
- self._test_path_validation(
- [
- "local_media_thumbnail_rel",
- "local_media_thumbnail",
- "remote_media_thumbnail_rel",
- "remote_media_thumbnail",
- "remote_media_thumbnail_rel_legacy",
- "url_cache_thumbnail_rel",
- "url_cache_thumbnail",
- ],
- parameter="content_type",
- valid_values=[
- "image/jpeg",
- ],
- invalid_values=[
- "", # ValueError: not enough values to unpack
- "image/jpeg/abc", # ValueError: too many values to unpack
- "image/jpeg\x00",
- ],
- )
-
- def test_thumbnail_method_validation(self) -> None:
- """Test validation of thumbnail methods"""
- self._test_path_validation(
- [
- "local_media_thumbnail_rel",
- "local_media_thumbnail",
- "remote_media_thumbnail_rel",
- "remote_media_thumbnail",
- "url_cache_thumbnail_rel",
- "url_cache_thumbnail",
- ],
- parameter="method",
- valid_values=[
- "crop",
- "scale",
- ],
- invalid_values=[
- "/scale",
- "scale/..",
- "scale\x00",
- "/",
- ],
- )
-
- def _test_path_validation(
- self,
- methods: Iterable[str],
- parameter: str,
- valid_values: Iterable[str],
- invalid_values: Iterable[str],
- ) -> None:
- """Test that the specified methods validate the named parameter as expected
-
- Args:
- methods: The names of `MediaFilePaths` methods to test
- parameter: The name of the parameter to test
- valid_values: A list of parameter values that are expected to be accepted
- invalid_values: A list of parameter values that are expected to be rejected
-
- Raises:
- AssertionError: If a value was accepted when it should have failed
- validation.
- ValueError: If a value failed validation when it should have been accepted.
- """
- for method in methods:
- get_path = getattr(self.filepaths, method)
-
- parameters = inspect.signature(get_path).parameters
- kwargs = {
- "server_name": "matrix.org",
- "media_id": "GerZNDnDZVjsOtardLuwfIBg",
- "file_id": "GerZNDnDZVjsOtardLuwfIBg",
- "width": 800,
- "height": 600,
- "content_type": "image/jpeg",
- "method": "scale",
- }
-
- if get_path.__name__.startswith("url_"):
- kwargs["media_id"] = "2020-01-02_GerZNDnDZVjsOtar"
-
- kwargs = {k: v for k, v in kwargs.items() if k in parameters}
- kwargs.pop(parameter)
-
- for value in valid_values:
- kwargs[parameter] = value
- get_path(**kwargs)
- # No exception should be raised
-
- for value in invalid_values:
- with self.assertRaises(ValueError):
- kwargs[parameter] = value
- path_or_list = get_path(**kwargs)
- self.fail(
- f"{value!r} unexpectedly passed validation: "
- f"{method} returned {path_or_list!r}"
- )
-
-
-class MediaFilePathsJailTestCase(unittest.TestCase):
- def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None:
- """Passes a relative path through the jail check.
-
- Args:
- filepaths: The `MediaFilePaths` instance.
- path: A path relative to the media store directory.
-
- Raises:
- ValueError: If the jail check fails.
- """
-
- @_wrap_with_jail_check(relative=True)
- def _make_relative_path(self: MediaFilePaths, path: str) -> str:
- return path
-
- _make_relative_path(filepaths, path)
-
- def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None:
- """Passes an absolute path through the jail check.
-
- Args:
- filepaths: The `MediaFilePaths` instance.
- path: A path relative to the media store directory.
-
- Raises:
- ValueError: If the jail check fails.
- """
-
- @_wrap_with_jail_check(relative=False)
- def _make_absolute_path(self: MediaFilePaths, path: str) -> str:
- return os.path.join(self.base_path, path)
-
- _make_absolute_path(filepaths, path)
-
- def test_traversal_inside(self) -> None:
- """Test the jail check for paths that stay within the media directory."""
- # Despite the `../`s, these paths still lie within the media directory and it's
- # expected for the jail check to allow them through.
- # These paths ought to trip the other checks in place and should never be
- # returned.
- filepaths = MediaFilePaths("/media_store")
- path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar"
- self._check_relative_path(filepaths, path)
- self._check_absolute_path(filepaths, path)
-
- def test_traversal_outside(self) -> None:
- """Test that the jail check fails for paths that escape the media directory."""
- filepaths = MediaFilePaths("/media_store")
- path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar"
- with self.assertRaises(ValueError):
- self._check_relative_path(filepaths, path)
- with self.assertRaises(ValueError):
- self._check_absolute_path(filepaths, path)
-
- def test_traversal_reentry(self) -> None:
- """Test the jail check for paths that exit and re-enter the media directory."""
- # These paths lie outside the media directory if it is a symlink, and inside
- # otherwise. Ideally the check should fail, but this proves difficult.
- # This test documents the behaviour for this edge case.
- # These paths ought to trip the other checks in place and should never be
- # returned.
- filepaths = MediaFilePaths("/media_store")
- path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar"
- self._check_relative_path(filepaths, path)
- self._check_absolute_path(filepaths, path)
-
- def test_symlink(self) -> None:
- """Test that a symlink does not cause the jail check to fail."""
- media_store_path = self.mktemp()
-
- # symlink the media store directory
- os.symlink("/mnt/synapse/media_store", media_store_path)
-
- # Test that relative and absolute paths don't trip the check
- # NB: `media_store_path` is a relative path
- filepaths = MediaFilePaths(media_store_path)
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
-
- filepaths = MediaFilePaths(os.path.abspath(media_store_path))
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
-
- def test_symlink_subdirectory(self) -> None:
- """Test that a symlinked subdirectory does not cause the jail check to fail."""
- media_store_path = self.mktemp()
- os.mkdir(media_store_path)
-
- # symlink `url_cache/`
- os.symlink(
- "/mnt/synapse/media_store_url_cache",
- os.path.join(media_store_path, "url_cache"),
- )
-
- # Test that relative and absolute paths don't trip the check
- # NB: `media_store_path` is a relative path
- filepaths = MediaFilePaths(media_store_path)
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
-
- filepaths = MediaFilePaths(os.path.abspath(media_store_path))
- self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
- self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar")
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
deleted file mode 100644
index 1062081a06..0000000000
--- a/tests/rest/media/v1/test_html_preview.py
+++ /dev/null
@@ -1,542 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest.media.v1.preview_html import (
- _get_html_media_encodings,
- decode_body,
- parse_html_to_open_graph,
- summarize_paragraphs,
-)
-
-from tests import unittest
-
-try:
- import lxml
-except ImportError:
- lxml = None
-
-
-class SummarizeTestCase(unittest.TestCase):
- if not lxml:
- skip = "url preview feature requires lxml"
-
- def test_long_summarize(self) -> None:
- example_paras = [
- """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
- Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
- Troms county, Norway. The administrative centre of the municipality is
- the city of Tromsø. Outside of Norway, Tromso and Tromsö are
- alternative spellings of the city.Tromsø is considered the northernmost
- city in the world with a population above 50,000. The most populous town
- north of it is Alta, Norway, with a population of 14,272 (2013).""",
- """Tromsø lies in Northern Norway. The municipality has a population of
- (2015) 72,066, but with an annual influx of students it has over 75,000
- most of the year. It is the largest urban area in Northern Norway and the
- third largest north of the Arctic Circle (following Murmansk and Norilsk).
- Most of Tromsø, including the city centre, is located on the island of
- Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,
- Tromsøya had a population of 36,088. Substantial parts of the urban area
- are also situated on the mainland to the east, and on parts of Kvaløya—a
- large island to the west. Tromsøya is connected to the mainland by the Tromsø
- Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the
- Sandnessund Bridge. Tromsø Airport connects the city to many destinations
- in Europe. The city is warmer than most other places located on the same
- latitude, due to the warming effect of the Gulf Stream.""",
- """The city centre of Tromsø contains the highest number of old wooden
- houses in Northern Norway, the oldest house dating from 1789. The Arctic
- Cathedral, a modern church from 1965, is probably the most famous landmark
- in Tromsø. The city is a cultural centre for its region, with several
- festivals taking place in the summer. Some of Norway's best-known
- musicians, Torbjørn Brundtland and Svein Berge of the electronica duo
- Röyksopp and Lene Marlin grew up and started their careers in Tromsø.
- Noted electronic musician Geir Jenssen also hails from Tromsø.""",
- ]
-
- desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
-
- self.assertEqual(
- desc,
- "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
- " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
- " Troms county, Norway. The administrative centre of the municipality is"
- " the city of Tromsø. Outside of Norway, Tromso and Tromsö are"
- " alternative spellings of the city.Tromsø is considered the northernmost"
- " city in the world with a population above 50,000. The most populous town"
- " north of it is Alta, Norway, with a population of 14,272 (2013).",
- )
-
- desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
-
- self.assertEqual(
- desc,
- "Tromsø lies in Northern Norway. The municipality has a population of"
- " (2015) 72,066, but with an annual influx of students it has over 75,000"
- " most of the year. It is the largest urban area in Northern Norway and the"
- " third largest north of the Arctic Circle (following Murmansk and Norilsk)."
- " Most of Tromsø, including the city centre, is located on the island of"
- " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,"
- " Tromsøya had a population of 36,088. Substantial parts of the urban…",
- )
-
- def test_short_summarize(self) -> None:
- example_paras = [
- "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
- " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
- " Troms county, Norway.",
- "Tromsø lies in Northern Norway. The municipality has a population of"
- " (2015) 72,066, but with an annual influx of students it has over 75,000"
- " most of the year.",
- "The city centre of Tromsø contains the highest number of old wooden"
- " houses in Northern Norway, the oldest house dating from 1789. The Arctic"
- " Cathedral, a modern church from 1965, is probably the most famous landmark"
- " in Tromsø.",
- ]
-
- desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
-
- self.assertEqual(
- desc,
- "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
- " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
- " Troms county, Norway.\n"
- "\n"
- "Tromsø lies in Northern Norway. The municipality has a population of"
- " (2015) 72,066, but with an annual influx of students it has over 75,000"
- " most of the year.",
- )
-
- def test_small_then_large_summarize(self) -> None:
- example_paras = [
- "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
- " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
- " Troms county, Norway.",
- "Tromsø lies in Northern Norway. The municipality has a population of"
- " (2015) 72,066, but with an annual influx of students it has over 75,000"
- " most of the year."
- " The city centre of Tromsø contains the highest number of old wooden"
- " houses in Northern Norway, the oldest house dating from 1789. The Arctic"
- " Cathedral, a modern church from 1965, is probably the most famous landmark"
- " in Tromsø.",
- ]
-
- desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
- self.assertEqual(
- desc,
- "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
- " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
- " Troms county, Norway.\n"
- "\n"
- "Tromsø lies in Northern Norway. The municipality has a population of"
- " (2015) 72,066, but with an annual influx of students it has over 75,000"
- " most of the year. The city centre of Tromsø contains the highest number"
- " of old wooden houses in Northern Norway, the oldest house dating from"
- " 1789. The Arctic Cathedral, a modern church from…",
- )
-
-
-class OpenGraphFromHtmlTestCase(unittest.TestCase):
- if not lxml:
- skip = "url preview feature requires lxml"
-
- def test_simple(self) -> None:
- html = b"""
- <html>
- <head><title>Foo</title></head>
- <body>
- Some text.
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
-
- def test_comment(self) -> None:
- html = b"""
- <html>
- <head><title>Foo</title></head>
- <body>
- <!-- HTML comment -->
- Some text.
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
-
- def test_comment2(self) -> None:
- html = b"""
- <html>
- <head><title>Foo</title></head>
- <body>
- Some text.
- <!-- HTML comment -->
- Some more text.
- <p>Text</p>
- More text
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(
- og,
- {
- "og:title": "Foo",
- "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text",
- },
- )
-
- def test_script(self) -> None:
- html = b"""
- <html>
- <head><title>Foo</title></head>
- <body>
- <script> (function() {})() </script>
- Some text.
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
-
- def test_missing_title(self) -> None:
- html = b"""
- <html>
- <body>
- Some text.
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
-
- # Another variant is a title with no content.
- html = b"""
- <html>
- <head><title></title></head>
- <body>
- <h1>Title</h1>
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
-
- def test_h1_as_title(self) -> None:
- html = b"""
- <html>
- <meta property="og:description" content="Some text."/>
- <body>
- <h1>Title</h1>
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
-
- def test_empty_description(self) -> None:
- """Description tags with empty content should be ignored."""
- html = b"""
- <html>
- <meta property="og:description" content=""/>
- <meta property="og:description"/>
- <meta name="description" content=""/>
- <meta name="description"/>
- <meta name="description" content="Finally!"/>
- <body>
- <h1>Title</h1>
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
-
- def test_missing_title_and_broken_h1(self) -> None:
- html = b"""
- <html>
- <body>
- <h1><a href="foo"/></h1>
- Some text.
- </body>
- </html>
- """
-
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
-
- self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
-
- def test_empty(self) -> None:
- """Test a body with no data in it."""
- html = b""
- tree = decode_body(html, "http://example.com/test.html")
- self.assertIsNone(tree)
-
- def test_no_tree(self) -> None:
- """A valid body with no tree in it."""
- html = b"\x00"
- tree = decode_body(html, "http://example.com/test.html")
- self.assertIsNone(tree)
-
- def test_xml(self) -> None:
- """Test decoding XML and ensure it works properly."""
- # Note that the strip() call is important to ensure the xml tag starts
- # at the initial byte.
- html = b"""
- <?xml version="1.0" encoding="UTF-8"?>
-
- <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
- <html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en" lang="en">
- <head><title>Foo</title></head><body>Some text.</body></html>
- """.strip()
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
- self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
-
- def test_invalid_encoding(self) -> None:
- """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
- html = b"""
- <html>
- <head><title>Foo</title></head>
- <body>
- Some text.
- </body>
- </html>
- """
- tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
- og = parse_html_to_open_graph(tree)
- self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
-
- def test_invalid_encoding2(self) -> None:
- """A body which doesn't match the sent character encoding."""
- # Note that this contains an invalid UTF-8 sequence in the title.
- html = b"""
- <html>
- <head><title>\xff\xff Foo</title></head>
- <body>
- Some text.
- </body>
- </html>
- """
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
- self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
-
- def test_windows_1252(self) -> None:
- """A body which uses cp1252, but doesn't declare that."""
- html = b"""
- <html>
- <head><title>\xf3</title></head>
- <body>
- Some text.
- </body>
- </html>
- """
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
- self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
-
- def test_twitter_tag(self) -> None:
- """Twitter card tags should be used if nothing else is available."""
- html = b"""
- <html>
- <meta name="twitter:card" content="summary">
- <meta name="twitter:description" content="Description">
- <meta name="twitter:site" content="@matrixdotorg">
- </html>
- """
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
- self.assertEqual(
- og,
- {
- "og:title": None,
- "og:description": "Description",
- "og:site_name": "@matrixdotorg",
- },
- )
-
- # But they shouldn't override Open Graph values.
- html = b"""
- <html>
- <meta name="twitter:card" content="summary">
- <meta name="twitter:description" content="Description">
- <meta property="og:description" content="Real Description">
- <meta name="twitter:site" content="@matrixdotorg">
- <meta property="og:site_name" content="matrix.org">
- </html>
- """
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
- self.assertEqual(
- og,
- {
- "og:title": None,
- "og:description": "Real Description",
- "og:site_name": "matrix.org",
- },
- )
-
- def test_nested_nodes(self) -> None:
- """A body with some nested nodes. Tests that we iterate over children
- in the right order (and don't reverse the order of the text)."""
- html = b"""
- <a href="somewhere">Welcome <b>the bold <u>and underlined text <svg>
- with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a>
- """
- tree = decode_body(html, "http://example.com/test.html")
- og = parse_html_to_open_graph(tree)
- self.assertEqual(
- og,
- {
- "og:title": None,
- "og:description": "Welcome\n\nthe bold\n\nand underlined text\n\nand\n\nsome\n\ntail text",
- },
- )
-
-
-class MediaEncodingTestCase(unittest.TestCase):
- def test_meta_charset(self) -> None:
- """A character encoding is found via the meta tag."""
- encodings = _get_html_media_encodings(
- b"""
- <html>
- <head><meta charset="ascii">
- </head>
- </html>
- """,
- "text/html",
- )
- self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
-
- # A less well-formed version.
- encodings = _get_html_media_encodings(
- b"""
- <html>
- <head>< meta charset = ascii>
- </head>
- </html>
- """,
- "text/html",
- )
- self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
-
- def test_meta_charset_underscores(self) -> None:
- """A character encoding contains underscore."""
- encodings = _get_html_media_encodings(
- b"""
- <html>
- <head><meta charset="Shift_JIS">
- </head>
- </html>
- """,
- "text/html",
- )
- self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
-
- def test_xml_encoding(self) -> None:
- """A character encoding is found via the meta tag."""
- encodings = _get_html_media_encodings(
- b"""
- <?xml version="1.0" encoding="ascii"?>
- <html>
- </html>
- """,
- "text/html",
- )
- self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
-
- def test_meta_xml_encoding(self) -> None:
- """Meta tags take precedence over XML encoding."""
- encodings = _get_html_media_encodings(
- b"""
- <?xml version="1.0" encoding="ascii"?>
- <html>
- <head><meta charset="UTF-16">
- </head>
- </html>
- """,
- "text/html",
- )
- self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
-
- def test_content_type(self) -> None:
- """A character encoding is found via the Content-Type header."""
- # Test a few variations of the header.
- headers = (
- 'text/html; charset="ascii";',
- "text/html;charset=ascii;",
- 'text/html; charset="ascii"',
- "text/html; charset=ascii",
- 'text/html; charset="ascii;',
- 'text/html; charset=ascii";',
- )
- for header in headers:
- encodings = _get_html_media_encodings(b"", header)
- self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
-
- def test_fallback(self) -> None:
- """A character encoding cannot be found in the body or header."""
- encodings = _get_html_media_encodings(b"", "text/html")
- self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
- def test_duplicates(self) -> None:
- """Ensure each encoding is only attempted once."""
- encodings = _get_html_media_encodings(
- b"""
- <?xml version="1.0" encoding="utf8"?>
- <html>
- <head><meta charset="UTF-8">
- </head>
- </html>
- """,
- 'text/html; charset="UTF_8"',
- )
- self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
- def test_unknown_invalid(self) -> None:
- """A character encoding should be ignored if it is unknown or invalid."""
- encodings = _get_html_media_encodings(
- b"""
- <html>
- <head><meta charset="invalid">
- </head>
- </html>
- """,
- 'text/html; charset="invalid"',
- )
- self.assertEqual(list(encodings), ["utf-8", "cp1252"])
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
deleted file mode 100644
index d18fc13c21..0000000000
--- a/tests/rest/media/v1/test_media_storage.py
+++ /dev/null
@@ -1,782 +0,0 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-import shutil
-import tempfile
-from binascii import unhexlify
-from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
-from unittest.mock import Mock
-from urllib import parse
-
-import attr
-from parameterized import parameterized, parameterized_class
-from PIL import Image as Image
-from typing_extensions import Literal
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.errors import Codes
-from synapse.events import EventBase
-from synapse.events.spamcheck import load_legacy_spam_checkers
-from synapse.logging.context import make_deferred_yieldable
-from synapse.module_api import ModuleApi
-from synapse.rest import admin
-from synapse.rest.client import login
-from synapse.rest.media.v1._base import FileInfo
-from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
-from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
-from synapse.server import HomeServer
-from synapse.types import RoomAlias
-from synapse.util import Clock
-
-from tests import unittest
-from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import SMALL_PNG
-from tests.utils import default_config
-
-
-class MediaStorageTests(unittest.HomeserverTestCase):
-
- needs_threadpool = True
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
- self.addCleanup(shutil.rmtree, self.test_dir)
-
- self.primary_base_path = os.path.join(self.test_dir, "primary")
- self.secondary_base_path = os.path.join(self.test_dir, "secondary")
-
- hs.config.media.media_store_path = self.primary_base_path
-
- storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
-
- self.filepaths = MediaFilePaths(self.primary_base_path)
- self.media_storage = MediaStorage(
- hs, self.primary_base_path, self.filepaths, storage_providers
- )
-
- def test_ensure_media_is_in_local_cache(self) -> None:
- media_id = "some_media_id"
- test_body = "Test\n"
-
- # First we create a file that is in a storage provider but not in the
- # local primary media store
- rel_path = self.filepaths.local_media_filepath_rel(media_id)
- secondary_path = os.path.join(self.secondary_base_path, rel_path)
-
- os.makedirs(os.path.dirname(secondary_path))
-
- with open(secondary_path, "w") as f:
- f.write(test_body)
-
- # Now we run ensure_media_is_in_local_cache, which should copy the file
- # to the local cache.
- file_info = FileInfo(None, media_id)
-
- # This uses a real blocking threadpool so we have to wait for it to be
- # actually done :/
- x = defer.ensureDeferred(
- self.media_storage.ensure_media_is_in_local_cache(file_info)
- )
-
- # Hotloop until the threadpool does its job...
- self.wait_on_thread(x)
-
- local_path = self.get_success(x)
-
- self.assertTrue(os.path.exists(local_path))
-
- # Asserts the file is under the expected local cache directory
- self.assertEqual(
- os.path.commonprefix([self.primary_base_path, local_path]),
- self.primary_base_path,
- )
-
- with open(local_path) as f:
- body = f.read()
-
- self.assertEqual(test_body, body)
-
-
-@attr.s(auto_attribs=True, slots=True, frozen=True)
-class _TestImage:
- """An image for testing thumbnailing with the expected results
-
- Attributes:
- data: The raw image to thumbnail
- content_type: The type of the image as a content type, e.g. "image/png"
- extension: The extension associated with the format, e.g. ".png"
-        expected_cropped: The expected bytes from cropped thumbnailing, or None if
-            the test should just check for success.
-        expected_scaled: The expected bytes from scaled thumbnailing, or None if
-            the test should just check that a valid image is returned.
- expected_found: True if the file should exist on the server, or False if
- a 404/400 is expected.
- unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
- False if the thumbnailing should succeed or a normal 404 is expected.
- """
-
- data: bytes
- content_type: bytes
- extension: bytes
- expected_cropped: Optional[bytes] = None
- expected_scaled: Optional[bytes] = None
- expected_found: bool = True
- unable_to_thumbnail: bool = False
-
-
-@parameterized_class(
- ("test_image",),
- [
- # small png
- (
- _TestImage(
- SMALL_PNG,
- b"image/png",
- b".png",
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000020000000200806"
- b"000000737a7af40000001a49444154789cedc101010000008220"
- b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
- b"44ae426082"
- ),
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000d49444154789c636060606000000005"
- b"0001a5f645400000000049454e44ae426082"
- ),
- ),
- ),
- # small png with transparency.
- (
- _TestImage(
- unhexlify(
- b"89504e470d0a1a0a0000000d49484452000000010000000101000"
- b"00000376ef9240000000274524e5300010194fdae0000000a4944"
- b"4154789c636800000082008177cd72b60000000049454e44ae426"
- b"082"
- ),
- b"image/png",
- b".png",
- # Note that we don't check the output since it varies across
- # different versions of Pillow.
- ),
- ),
- # small lossless webp
- (
- _TestImage(
- unhexlify(
- b"524946461a000000574542505650384c0d0000002f0000001007"
- b"1011118888fe0700"
- ),
- b"image/webp",
- b".webp",
- ),
- ),
- # an empty file
- (
- _TestImage(
- b"",
- b"image/gif",
- b".gif",
- expected_found=False,
- unable_to_thumbnail=True,
- ),
- ),
- ],
-)
-class MediaRepoTests(unittest.HomeserverTestCase):
-
- hijack_auth = True
- user_id = "@test:user"
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
- self.fetches = []
-
- def get_file(
- destination: str,
- path: str,
- output_stream: BinaryIO,
- args: Optional[Dict[str, Union[str, List[str]]]] = None,
- max_size: Optional[int] = None,
- ) -> Deferred:
- """
-            A stub for the federation client's get_file: it records the request
-            and returns a Deferred which the test later fires with the response
-            body and a (length, headers) tuple.
- """
-
- def write_to(r):
- data, response = r
- output_stream.write(data)
- return response
-
- d = Deferred()
- d.addCallback(write_to)
- self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
-
- client = Mock()
- client.get_file = get_file
-
- self.storage_path = self.mktemp()
- self.media_store_path = self.mktemp()
- os.mkdir(self.storage_path)
- os.mkdir(self.media_store_path)
-
- config = self.default_config()
- config["media_store_path"] = self.media_store_path
- config["max_image_pixels"] = 2000000
-
- provider_config = {
- "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
- "store_local": True,
- "store_synchronous": False,
- "store_remote": True,
- "config": {"directory": self.storage_path},
- }
- config["media_storage_providers"] = [provider_config]
-
- hs = self.setup_test_homeserver(config=config, federation_http_client=client)
-
- return hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
- media_resource = hs.get_media_repository_resource()
- self.download_resource = media_resource.children[b"download"]
- self.thumbnail_resource = media_resource.children[b"thumbnail"]
- self.store = hs.get_datastores().main
- self.media_repo = hs.get_media_repository()
-
- self.media_id = "example.com/12345"
-
- def _req(
- self, content_disposition: Optional[bytes], include_content_type: bool = True
- ) -> FakeChannel:
- channel = make_request(
- self.reactor,
- FakeSite(self.download_resource, self.reactor),
- "GET",
- self.media_id,
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- # We've made one fetch, to example.com, using the media URL, and asking
- # the other server not to do a remote fetch
- self.assertEqual(len(self.fetches), 1)
- self.assertEqual(self.fetches[0][1], "example.com")
- self.assertEqual(
- self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
- )
- self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
-
- headers = {
- b"Content-Length": [b"%d" % (len(self.test_image.data))],
- }
-
- if include_content_type:
- headers[b"Content-Type"] = [self.test_image.content_type]
-
- if content_disposition:
- headers[b"Content-Disposition"] = [content_disposition]
-
- self.fetches[0][0].callback(
- (self.test_image.data, (len(self.test_image.data), headers))
- )
-
- self.pump()
- self.assertEqual(channel.code, 200)
-
- return channel
-
- def test_handle_missing_content_type(self) -> None:
- channel = self._req(
- b"inline; filename=out" + self.test_image.extension,
- include_content_type=False,
- )
- headers = channel.headers
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
- )
-
- def test_disposition_filename_ascii(self) -> None:
- """
- If the filename is filename=<ascii> then Synapse will decode it as an
- ASCII string, and use filename= in the response.
- """
- channel = self._req(b"inline; filename=out" + self.test_image.extension)
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
- )
- self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"),
- [b"inline; filename=out" + self.test_image.extension],
- )
-
- def test_disposition_filenamestar_utf8escaped(self) -> None:
- """
-        If the filename is filename*=utf-8''<utf8 escaped> then Synapse will
- correctly decode it as the UTF-8 string, and use filename* in the
- response.
- """
- filename = parse.quote("\u2603".encode()).encode("ascii")
- channel = self._req(
- b"inline; filename*=utf-8''" + filename + self.test_image.extension
- )
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
- )
- self.assertEqual(
- headers.getRawHeaders(b"Content-Disposition"),
- [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
- )
-
- def test_disposition_none(self) -> None:
- """
- If there is no filename, one isn't passed on in the Content-Disposition
-        of the response.
- """
- channel = self._req(None)
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
- )
- self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
-
- def test_thumbnail_crop(self) -> None:
- """Test that a cropped remote thumbnail is available."""
- self._test_thumbnail(
- "crop",
- self.test_image.expected_cropped,
- expected_found=self.test_image.expected_found,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- def test_thumbnail_scale(self) -> None:
- """Test that a scaled remote thumbnail is available."""
- self._test_thumbnail(
- "scale",
- self.test_image.expected_scaled,
- expected_found=self.test_image.expected_found,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- def test_invalid_type(self) -> None:
- """An invalid thumbnail type is never available."""
- self._test_thumbnail(
- "invalid",
- None,
- expected_found=False,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- @unittest.override_config(
- {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
- )
- def test_no_thumbnail_crop(self) -> None:
- """
- Override the config to generate only scaled thumbnails, but request a cropped one.
- """
- self._test_thumbnail(
- "crop",
- None,
- expected_found=False,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- @unittest.override_config(
- {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
- )
- def test_no_thumbnail_scale(self) -> None:
- """
- Override the config to generate only cropped thumbnails, but request a scaled one.
- """
- self._test_thumbnail(
- "scale",
- None,
- expected_found=False,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- def test_thumbnail_repeated_thumbnail(self) -> None:
- """Test that fetching the same thumbnail works, and deleting the on disk
- thumbnail regenerates it.
- """
- self._test_thumbnail(
- "scale",
- self.test_image.expected_scaled,
- expected_found=self.test_image.expected_found,
- unable_to_thumbnail=self.test_image.unable_to_thumbnail,
- )
-
- if not self.test_image.expected_found:
- return
-
- # Fetching again should work, without re-requesting the image from the
- # remote.
- params = "?width=32&height=32&method=scale"
- channel = make_request(
- self.reactor,
- FakeSite(self.thumbnail_resource, self.reactor),
- "GET",
- self.media_id + params,
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- self.assertEqual(channel.code, 200)
- if self.test_image.expected_scaled:
- self.assertEqual(
- channel.result["body"],
- self.test_image.expected_scaled,
- channel.result["body"],
- )
-
- # Deleting the thumbnail on disk then re-requesting it should work as
- # Synapse should regenerate missing thumbnails.
- origin, media_id = self.media_id.split("/")
- info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
- file_id = info["filesystem_id"]
-
- thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
- origin, file_id
- )
- shutil.rmtree(thumbnail_dir, ignore_errors=True)
-
- channel = make_request(
- self.reactor,
- FakeSite(self.thumbnail_resource, self.reactor),
- "GET",
- self.media_id + params,
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- self.assertEqual(channel.code, 200)
- if self.test_image.expected_scaled:
- self.assertEqual(
- channel.result["body"],
- self.test_image.expected_scaled,
- channel.result["body"],
- )
-
- def _test_thumbnail(
- self,
- method: str,
- expected_body: Optional[bytes],
- expected_found: bool,
- unable_to_thumbnail: bool = False,
- ) -> None:
- """Test the given thumbnailing method works as expected.
-
- Args:
- method: The thumbnailing method to use (crop, scale).
- expected_body: The expected bytes from thumbnailing, or None if
- test should just check for a valid image.
- expected_found: True if the file should exist on the server, or False if
- a 404/400 is expected.
- unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
- False if the thumbnailing should succeed or a normal 404 is expected.
- """
-
- params = "?width=32&height=32&method=" + method
- channel = make_request(
- self.reactor,
- FakeSite(self.thumbnail_resource, self.reactor),
- "GET",
- self.media_id + params,
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- headers = {
- b"Content-Length": [b"%d" % (len(self.test_image.data))],
- b"Content-Type": [self.test_image.content_type],
- }
- self.fetches[0][0].callback(
- (self.test_image.data, (len(self.test_image.data), headers))
- )
- self.pump()
-
- if expected_found:
- self.assertEqual(channel.code, 200)
-
- self.assertEqual(
- channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
- [b"cross-origin"],
- )
-
- if expected_body is not None:
- self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
- )
- else:
- # ensure that the result is at least some valid image
- Image.open(BytesIO(channel.result["body"]))
- elif unable_to_thumbnail:
- # A 400 with a JSON body.
- self.assertEqual(channel.code, 400)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_UNKNOWN",
- "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
- },
- )
- else:
- # A 404 with a JSON body.
- self.assertEqual(channel.code, 404)
- self.assertEqual(
- channel.json_body,
- {
- "errcode": "M_NOT_FOUND",
- "error": "Not found [b'example.com', b'12345']",
- },
- )
-
- @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
- def test_same_quality(self, method: str, desired_size: int) -> None:
- """Test that choosing between thumbnails with the same quality rating succeeds.
-
- We are not particular about which thumbnail is chosen."""
- self.assertIsNotNone(
- self.thumbnail_resource._select_thumbnail(
- desired_width=desired_size,
- desired_height=desired_size,
- desired_method=method,
- desired_type=self.test_image.content_type,
- # Provide two identical thumbnails which are guaranteed to have the same
- # quality rating.
- thumbnail_infos=[
- {
- "thumbnail_width": 32,
- "thumbnail_height": 32,
- "thumbnail_method": method,
- "thumbnail_type": self.test_image.content_type,
- "thumbnail_length": 256,
- "filesystem_id": f"thumbnail1{self.test_image.extension}",
- },
- {
- "thumbnail_width": 32,
- "thumbnail_height": 32,
- "thumbnail_method": method,
- "thumbnail_type": self.test_image.content_type,
- "thumbnail_length": 256,
- "filesystem_id": f"thumbnail2{self.test_image.extension}",
- },
- ],
- file_id=f"image{self.test_image.extension}",
- url_cache=None,
- server_name=None,
- )
- )
-
- def test_x_robots_tag_header(self) -> None:
- """
-        Tests that the `X-Robots-Tag` header is present, telling web crawlers
-        not to index, archive, or follow links in media.
- """
- channel = self._req(b"inline; filename=out" + self.test_image.extension)
-
- headers = channel.headers
- self.assertEqual(
- headers.getRawHeaders(b"X-Robots-Tag"),
- [b"noindex, nofollow, noarchive, noimageindex"],
- )
-
- def test_cross_origin_resource_policy_header(self) -> None:
- """
- Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
- allowing web clients to embed media from the downloads API.
- """
- channel = self._req(b"inline; filename=out" + self.test_image.extension)
-
- headers = channel.headers
-
- self.assertEqual(
- headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
- [b"cross-origin"],
- )
-
-
-class TestSpamCheckerLegacy:
- """A spam checker module that rejects all media that includes the bytes
- `evil`.
-
- Uses the legacy Spam-Checker API.
- """
-
- def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
- self.config = config
- self.api = api
-
- def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
- return config
-
- async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
- return False # allow all events
-
- async def user_may_invite(
- self,
- inviter_userid: str,
- invitee_userid: str,
- room_id: str,
- ) -> bool:
- return True # allow all invites
-
- async def user_may_create_room(self, userid: str) -> bool:
- return True # allow all room creations
-
- async def user_may_create_room_alias(
- self, userid: str, room_alias: RoomAlias
- ) -> bool:
- return True # allow all room aliases
-
- async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
- return True # allow publishing of all rooms
-
- async def check_media_file_for_spam(
- self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
- ) -> bool:
- buf = BytesIO()
- await file_wrapper.write_chunks_to(buf.write)
-
- return b"evil" in buf.getvalue()
-
-
-class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
- servlets = [
- login.register_servlets,
- admin.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user = self.register_user("user", "pass")
- self.tok = self.login("user", "pass")
-
- # Allow for uploading and downloading to/from the media repo
- self.media_repo = hs.get_media_repository_resource()
- self.download_resource = self.media_repo.children[b"download"]
- self.upload_resource = self.media_repo.children[b"upload"]
-
- load_legacy_spam_checkers(hs)
-
- def default_config(self) -> Dict[str, Any]:
- config = default_config("test")
-
- config.update(
- {
- "spam_checker": [
- {
- "module": TestSpamCheckerLegacy.__module__
- + ".TestSpamCheckerLegacy",
- "config": {},
- }
- ]
- }
- )
-
- return config
-
- def test_upload_innocent(self) -> None:
- """Attempt to upload some innocent data that should be allowed."""
- self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
- )
-
- def test_upload_ban(self) -> None:
- """Attempt to upload some data that includes bytes "evil", which should
- get rejected by the spam checker.
- """
-
- data = b"Some evil data"
-
- self.helper.upload_media(
- self.upload_resource, data, tok=self.tok, expect_code=400
- )
-
-
-EVIL_DATA = b"Some evil data"
-EVIL_DATA_EXPERIMENT = b"Some evil data to trigger the experimental tuple API"
-
-
-class SpamCheckerTestCase(unittest.HomeserverTestCase):
- servlets = [
- login.register_servlets,
- admin.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user = self.register_user("user", "pass")
- self.tok = self.login("user", "pass")
-
- # Allow for uploading and downloading to/from the media repo
- self.media_repo = hs.get_media_repository_resource()
- self.download_resource = self.media_repo.children[b"download"]
- self.upload_resource = self.media_repo.children[b"upload"]
-
- hs.get_module_api().register_spam_checker_callbacks(
- check_media_file_for_spam=self.check_media_file_for_spam
- )
-
- async def check_media_file_for_spam(
- self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
- ) -> Union[Codes, Literal["NOT_SPAM"]]:
- buf = BytesIO()
- await file_wrapper.write_chunks_to(buf.write)
-
- if buf.getvalue() == EVIL_DATA:
- return Codes.FORBIDDEN
- elif buf.getvalue() == EVIL_DATA_EXPERIMENT:
- return (Codes.FORBIDDEN, {})
- else:
- return "NOT_SPAM"
-
- def test_upload_innocent(self) -> None:
- """Attempt to upload some innocent data that should be allowed."""
- self.helper.upload_media(
- self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
- )
-
- def test_upload_ban(self) -> None:
- """Attempt to upload some data that includes bytes "evil", which should
- get rejected by the spam checker.
- """
-
- self.helper.upload_media(
- self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
- )
-
- self.helper.upload_media(
- self.upload_resource,
- EVIL_DATA_EXPERIMENT,
- tok=self.tok,
- expect_code=400,
- )
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
deleted file mode 100644
index 3f7f1dbab9..0000000000
--- a/tests/rest/media/v1/test_oembed.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-
-from parameterized import parameterized
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
-from synapse.server import HomeServer
-from synapse.types import JsonDict
-from synapse.util import Clock
-
-from tests.unittest import HomeserverTestCase
-
-try:
- import lxml
-except ImportError:
- lxml = None
-
-
-class OEmbedTests(HomeserverTestCase):
- if not lxml:
- skip = "url preview feature requires lxml"
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.oembed = OEmbedProvider(hs)
-
- def parse_response(self, response: JsonDict) -> OEmbedResult:
- return self.oembed.parse_oembed_response(
- "https://test", json.dumps(response).encode("utf-8")
- )
-
- def test_version(self) -> None:
- """Accept versions that are similar to 1.0 as a string or int (or missing)."""
- for version in ("1.0", 1.0, 1):
- result = self.parse_response({"version": version})
- # An empty Open Graph response is an error, ensure the URL is included.
- self.assertIn("og:url", result.open_graph_result)
-
- # A missing version should be treated as 1.0.
- result = self.parse_response({"type": "link"})
- self.assertIn("og:url", result.open_graph_result)
-
- # Invalid versions should be rejected.
- for version in ("2.0", "1", 1.1, 0, None, {}, []):
- result = self.parse_response({"version": version, "type": "link"})
-            # An empty Open Graph response indicates an error.
- self.assertEqual({}, result.open_graph_result)
-
- def test_cache_age(self) -> None:
- """Ensure a cache-age is parsed properly."""
- # Correct-ish cache ages are allowed.
- for cache_age in ("1", 1.0, 1):
- result = self.parse_response({"cache_age": cache_age})
- self.assertEqual(result.cache_age, 1000)
-
- # Invalid cache ages are ignored.
- for cache_age in ("invalid", {}):
- result = self.parse_response({"cache_age": cache_age})
- self.assertIsNone(result.cache_age)
-
- # Cache age is optional.
- result = self.parse_response({})
- self.assertIsNone(result.cache_age)
-
- @parameterized.expand(
- [
- ("title", "title"),
- ("provider_name", "site_name"),
- ("thumbnail_url", "image"),
- ],
- name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}",
- )
- def test_property(self, oembed_property: str, open_graph_property: str) -> None:
- """Test properties which must be strings."""
- result = self.parse_response({oembed_property: "test"})
- self.assertIn(f"og:{open_graph_property}", result.open_graph_result)
- self.assertEqual(result.open_graph_result[f"og:{open_graph_property}"], "test")
-
- result = self.parse_response({oembed_property: 1})
- self.assertNotIn(f"og:{open_graph_property}", result.open_graph_result)
-
- def test_author_name(self) -> None:
- """Test the author_name property."""
- result = self.parse_response({"author_name": "test"})
- self.assertEqual(result.author_name, "test")
-
- result = self.parse_response({"author_name": 1})
- self.assertIsNone(result.author_name)
-
- def test_rich(self) -> None:
- """Test a type of rich."""
- result = self.parse_response({"html": "test<img src='foo'>", "type": "rich"})
- self.assertIn("og:description", result.open_graph_result)
- self.assertIn("og:image", result.open_graph_result)
- self.assertEqual(result.open_graph_result["og:description"], "test")
- self.assertEqual(result.open_graph_result["og:image"], "foo")
-
- result = self.parse_response({"type": "rich"})
- self.assertNotIn("og:description", result.open_graph_result)
-
- result = self.parse_response({"html": 1, "type": "rich"})
- self.assertNotIn("og:description", result.open_graph_result)
-
- def test_photo(self) -> None:
- """Test a type of photo."""
- result = self.parse_response({"url": "test", "type": "photo"})
- self.assertIn("og:image", result.open_graph_result)
- self.assertEqual(result.open_graph_result["og:image"], "test")
-
- result = self.parse_response({"type": "photo"})
- self.assertNotIn("og:image", result.open_graph_result)
-
- result = self.parse_response({"url": 1, "type": "photo"})
- self.assertNotIn("og:image", result.open_graph_result)
-
- def test_video(self) -> None:
- """Test a type of video."""
- result = self.parse_response({"html": "test", "type": "video"})
- self.assertIn("og:type", result.open_graph_result)
- self.assertEqual(result.open_graph_result["og:type"], "video.other")
- self.assertIn("og:description", result.open_graph_result)
- self.assertEqual(result.open_graph_result["og:description"], "test")
-
- result = self.parse_response({"type": "video"})
- self.assertIn("og:type", result.open_graph_result)
- self.assertEqual(result.open_graph_result["og:type"], "video.other")
- self.assertNotIn("og:description", result.open_graph_result)
-
- result = self.parse_response({"url": 1, "type": "video"})
- self.assertIn("og:type", result.open_graph_result)
- self.assertEqual(result.open_graph_result["og:type"], "video.other")
- self.assertNotIn("og:description", result.open_graph_result)
-
- def test_link(self) -> None:
- """Test type of link."""
- result = self.parse_response({"type": "link"})
- self.assertIn("og:type", result.open_graph_result)
- self.assertEqual(result.open_graph_result["og:type"], "website")
-
- def test_title_html_entities(self) -> None:
- """Test HTML entities in title"""
- result = self.parse_response(
- {"title": "Why JSON isn’t a Good Configuration Language"}
- )
- self.assertEqual(
- result.open_graph_result["og:title"],
- "Why JSON isn’t a Good Configuration Language",
- )