From 9bb2eac71962970d02842bca441f4bcdbbf93a11 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Feb 2023 15:29:09 -0500 Subject: Bump black from 22.12.0 to 23.1.0 (#15103) --- tests/rest/admin/test_device.py | 3 --- tests/rest/admin/test_media.py | 5 ----- tests/rest/admin/test_room.py | 1 - tests/rest/admin/test_server_notice.py | 1 - tests/rest/client/test_account.py | 4 ---- tests/rest/client/test_auth.py | 2 -- tests/rest/client/test_capabilities.py | 1 - tests/rest/client/test_consent.py | 1 - tests/rest/client/test_directory.py | 1 - tests/rest/client/test_ephemeral_message.py | 1 - tests/rest/client/test_events.py | 3 --- tests/rest/client/test_filter.py | 1 - tests/rest/client/test_login.py | 2 -- tests/rest/client/test_login_token_request.py | 1 - tests/rest/client/test_presence.py | 1 - tests/rest/client/test_profile.py | 3 --- tests/rest/client/test_register.py | 4 ---- tests/rest/client/test_rendezvous.py | 1 - tests/rest/client/test_rooms.py | 14 ++------------ tests/rest/client/test_sync.py | 3 --- tests/rest/client/test_third_party_rules.py | 3 +++ tests/rest/media/test_media_retention.py | 1 - tests/rest/media/v1/test_media_storage.py | 3 --- tests/rest/media/v1/test_url_preview.py | 3 --- 24 files changed, 5 insertions(+), 58 deletions(-) (limited to 'tests/rest') diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 03f2112b07..aaa488bced 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -28,7 +28,6 @@ from tests import unittest class DeviceRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): class DevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index db77a45ae3..f41319a5b6 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -594,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -724,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -821,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 453a6e979c..9dbb778679 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, room.register_servlets, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index f71ff46d87..28b999573e 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class ServerNoticeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e2ee1a1766..2b05dffc7d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): class DeactivateTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase): class WhoamiTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index a144610078..0d8fe77b88 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -52,7 +52,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): class FallbackAuthTests(unittest.HomeserverTestCase): - servlets = [ auth.register_servlets, register.register_servlets, @@ -60,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = True diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index d1751e1557..c16e8d43f4 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -26,7 +26,6 @@ from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index b1ca81a911..bb845179d3 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["form_secret"] = "123abc" diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index 7a88aa2cda..6490e883bf 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class DirectoryTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, directory.register_servlets, diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 9fa1f82dfe..f31ebc8021 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -26,7 +26,6 @@ from tests import unittest class EphemeralMessageTestCase(unittest.HomeserverTestCase): - user_id = "@user:test" servlets = [ diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a9b7db9db2..54df2a252c 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = False config["enable_registration"] = True @@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") @@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 830762fd53..91678abf13 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" hijack_auth = True EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index ff5baa9f0a..62acf4f44e 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [ class LoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, ] diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index 6aedc1a11c..b8187db982 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, admin.register_servlets, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 67e16880e6..dcbb125a3b 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -35,7 +35,6 @@ class PresenceTestCase(unittest.HomeserverTestCase): servlets = [presence.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.presence_handler = Mock(spec=PresenceHandler) self.presence_handler.set_state.return_value = make_awaitable(None) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 8de5a342ae..27c93ad761 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -30,7 +30,6 @@ from tests import unittest class ProfileTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 4c561f9525..b228dba861 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register.register_servlets, @@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): class AccountValidityTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index c0eb5d01a6..8dbd64be55 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" class RendezvousServletTestCase(unittest.HomeserverTestCase): - servlets = [ rendezvous.register_servlets, ] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index cfad182b2f..4dd763096d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase): servlets = [room.register_servlets, room.register_deprecated_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.hs = self.setup_test_homeserver( "red", federation_http_client=None, @@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase): rmcreator_id = "@notme:red" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase): class RoomJoinTestCase(RoomBase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): hijack_auth = False def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # Register the user who does the searching self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") @@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() @@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["allow_public_rooms_without_auth"] = True self.hs = self.setup_test_homeserver(config=config) @@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase): class ContextTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): class ThreepidInviteTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -3438,7 +3427,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): """ Test allowing/blocking threepid invites with a spam-check module. - In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index b9047194dd..9c876c7a32 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -41,7 +41,6 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): class ExcludeRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 5fa3440691..c0f93f898a 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. """ + # patch the rules module with a Mock which will return False for some event # types async def check( @@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_modify_event(self) -> None: """The module can return a modified version of the event""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] @@ -275,6 +277,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index 23f227aed6..b59d9dfd4d 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -31,7 +31,6 @@ from tests.utils import MockClock class MediaRetentionTestCase(unittest.HomeserverTestCase): - ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 17a3b06a8e..8ed27179c4 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -52,7 +52,6 @@ from tests.utils import default_config class MediaStorageTests(unittest.HomeserverTestCase): - needs_threadpool = True def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -207,7 +206,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches: List[ Tuple[ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", @@ -268,7 +266,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 2c321f8d04..6fcf60ce19 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["url_preview_enabled"] = True config["max_spider_size"] = 9999999 @@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] @@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): addressTypes: Optional[Sequence[Type[IAddress]]] = None, transportSemantics: str = "TCP", ) -> IResolutionReceiver: - resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) if hostName not in self.lookups: -- cgit 1.5.1 From 682151a464f688768d5bd8308e16bd4024ad2e57 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Feb 2023 16:08:53 -0500 Subject: Do not fail completely if oEmbed autodiscovery fails. (#15092) Previously if an autodiscovered oEmbed request failed (e.g. the oEmbed endpoint is down or does not exist) then the entire URL preview would fail. Instead we now return everything we can, even if this additional request fails. --- changelog.d/15092.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 33 ++++++++++++++------ tests/rest/media/v1/test_url_preview.py | 44 +++++++++++++++++++++++++-- 3 files changed, 65 insertions(+), 13 deletions(-) create mode 100644 changelog.d/15092.bugfix (limited to 'tests/rest') diff --git a/changelog.d/15092.bugfix b/changelog.d/15092.bugfix new file mode 100644 index 0000000000..67509c5c69 --- /dev/null +++ b/changelog.d/15092.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a URL preview would break if the discovered oEmbed failed to download. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a8f6fd6b35..4a594ab9d8 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -163,6 +163,10 @@ class PreviewUrlResource(DirectServeJsonResource): 7. Stores the result in the database cache. 4. Returns the result. + If any additional requests (e.g. from oEmbed autodiscovery, step 5.3 or + image thumbnailing, step 5.4 or 6.4) fails then the URL preview as a whole + does not fail. As much information as possible is returned. + The in-memory cache expires after 1 hour. Expired entries in the database cache (and their associated media files) are @@ -364,16 +368,25 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._handle_url( - oembed_url, user, allow_data_urls=True - ) - ( - og_from_oembed, - author_name, - expiration_ms, - ) = await self._handle_oembed_response( - url, oembed_info, expiration_ms - ) + try: + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) + except Exception as e: + # Fetching the oEmbed info failed, don't block the entire URL preview. + logger.warning( + "oEmbed fetch failed during URL preview: %s errored with %s", + oembed_url, + e, + ) + else: + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( + url, oembed_info, expiration_ms + ) # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 6fcf60ce19..2acfccec61 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -657,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): """If the preview image doesn't exist, ensure some data is returned.""" self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] - end_content = ( + result = ( b"""""" ) @@ -678,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) - % (len(end_content),) - + end_content + % (len(result),) + + result ) self.pump() @@ -688,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase): # The image should not be in the result. self.assertNotIn("og:image", channel.json_body) + def test_oembed_failure(self) -> None: + """If the autodiscovered oEmbed URL fails, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + result = b""" + oEmbed Autodiscovery Fail + + """ + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(result),) + + result + ) + + self.pump() + self.assertEqual(channel.code, 200) + + # The image should not be in the result. + self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail") + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. -- cgit 1.5.1 From 1c95ddd09bbc46046a3412e7bb03a87aa3b6f65a Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 24 Feb 2023 13:15:29 -0800 Subject: Batch up storing state groups when creating new room (#14918) --- changelog.d/14918.misc | 1 + synapse/events/snapshot.py | 49 +++++++++++ synapse/handlers/message.py | 16 ++-- synapse/handlers/room.py | 37 ++++---- synapse/handlers/room_batch.py | 4 +- synapse/handlers/room_member.py | 13 ++- synapse/storage/databases/state/store.py | 119 ++++++++++++++++++++++++++ tests/handlers/test_message.py | 25 ++++-- tests/handlers/test_register.py | 3 +- tests/push/test_bulk_push_rule_evaluator.py | 13 +-- tests/rest/client/test_rooms.py | 4 +- tests/storage/test_event_chain.py | 6 +- tests/storage/test_state.py | 126 ++++++++++++++++++++++++++++ tests/unittest.py | 4 +- 14 files changed, 371 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14918.misc (limited to 'tests/rest') diff --git a/changelog.d/14918.misc b/changelog.d/14918.misc new file mode 100644 index 0000000000..828794354a --- /dev/null +++ b/changelog.d/14918.misc @@ -0,0 +1 @@ +Batch up storing state groups when creating a new room. \ No newline at end of file diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e0d82ad81c..a91a5d1e3c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers + from synapse.storage.databases import StateGroupDataStore from synapse.storage.databases.main import DataStore from synapse.types.state import StateFilter @@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase): partial_state: bool state_map_before_event: Optional[StateMap[str]] = None + @classmethod + async def batch_persist_unpersisted_contexts( + cls, + events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]], + room_id: str, + last_known_state_group: int, + datastore: "StateGroupDataStore", + ) -> List[Tuple[EventBase, EventContext]]: + """ + Takes a list of events and their associated unpersisted contexts and persists + the unpersisted contexts, returning a list of events and persisted contexts. + Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: A list of events and their unpersisted contexts + room_id: the room_id for the events + last_known_state_group: the last persisted state group + datastore: a state datastore + """ + amended_events_and_context = await datastore.store_state_deltas_for_batched( + events_and_context, room_id, last_known_state_group + ) + + events_and_persisted_context = [] + for event, unpersisted_context in amended_events_and_context: + if event.is_state(): + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, + ) + else: + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_before_event, + delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, + ) + events_and_persisted_context.append((event, context)) + return events_and_persisted_context + async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index aa90d0000d..e433d6b01f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -574,7 +574,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -721,8 +721,6 @@ class EventCreationHandler: current_state_group=current_state_group, ) - context = await unpersisted_context.persist(event) - # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -739,7 +737,7 @@ class EventCreationHandler: assert state_map is not None prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: - prev_state_ids = await context.get_prev_state_ids( + prev_state_ids = await unpersisted_context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) @@ -764,8 +762,7 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) - - return event, context + return event, unpersisted_context async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1005,7 +1002,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1016,6 +1013,7 @@ class EventCreationHandler: historical=historical, depth=depth, ) + context = await unpersisted_context.persist(event) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( event.sender, @@ -1190,7 +1188,6 @@ class EventCreationHandler: if for_batch: assert prev_event_ids is not None assert state_map is not None - assert current_state_group is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth @@ -2046,7 +2043,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, { "type": EventTypes.Dummy, @@ -2055,6 +2052,7 @@ class EventCreationHandler: "sender": user_id, }, ) + context = await unpersisted_context.persist(event) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a26ec02284..b1784638f4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,6 +51,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM @@ -211,7 +212,7 @@ class RoomCreationHandler: # the required power level to send the tombstone event. ( tombstone_event, - tombstone_context, + tombstone_unpersisted_context, ) = await self.event_creation_handler.create_event( requester, { @@ -225,6 +226,9 @@ class RoomCreationHandler: }, }, ) + tombstone_context = await tombstone_unpersisted_context.persist( + tombstone_event + ) validate_event_for_room_version(tombstone_event) await self._event_auth_handler.check_auth_rules_from_context( tombstone_event @@ -1092,7 +1096,7 @@ class RoomCreationHandler: content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: """ Creates an event and associated event context. Args: @@ -1111,20 +1115,23 @@ class RoomCreationHandler: event_dict = create_event_dict(etype, content, **kwargs) - new_event, new_context = await self.event_creation_handler.create_event( + ( + new_event, + new_unpersisted_context, + ) = await self.event_creation_handler.create_event( creator, event_dict, prev_event_ids=prev_event, depth=depth, state_map=state_map, for_batch=for_batch, - current_state_group=current_state_group, ) + depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_context + return new_event, new_unpersisted_context try: config = self._presets_dict[preset_config] @@ -1134,10 +1141,10 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - creation_event, creation_context = await create_event( + creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False ) - + creation_context = await unpersisted_creation_context.persist(creation_event) logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1181,7 +1188,6 @@ class RoomCreationHandler: power_event, power_context = await create_event( EventTypes.PowerLevels, pl_content, True ) - current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { @@ -1230,14 +1236,12 @@ class RoomCreationHandler: power_level_content, True, ) - current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) - current_state_group = room_alias_context._state_group events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: @@ -1246,7 +1250,6 @@ class RoomCreationHandler: {"join_rule": config["join_rules"]}, True, ) - current_state_group = join_rules_context._state_group events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: @@ -1255,7 +1258,6 @@ class RoomCreationHandler: {"history_visibility": config["history_visibility"]}, True, ) - current_state_group = visibility_context._state_group events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: @@ -1265,14 +1267,12 @@ class RoomCreationHandler: {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) - current_state_group = guest_access_context._state_group events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): event, context = await create_event( etype, content, True, state_key=state_key ) - current_state_group = context._state_group events_to_send.append((event, context)) if config["encrypted"]: @@ -1284,9 +1284,16 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_to_send, room_id, current_state_group, datastore + ) + ) + last_event = await self.event_creation_handler.handle_new_client_event( creator, - events_to_send, + events_and_context, ignore_shadow_ban=True, ratelimit=False, ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 5d4ca0e2d2..bf9df60218 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -327,7 +327,7 @@ class RoomBatchHandler: # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, context = await self.event_creation_handler.create_event( + event, unpersisted_context = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), @@ -345,7 +345,7 @@ class RoomBatchHandler: historical=True, depth=inherited_depth, ) - + context = await unpersisted_context.persist(event) assert context._state_group # Normally this is done when persisting the event but we have to diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a965c7ec76..de7476f300 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier=outlier, historical=historical, ) - + context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, event_dict, txn_id=txn_id, @@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): auth_event_ids=auth_event_ids, outlier=True, ) + context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True result_event = ( diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 89b1faa6c8..bf4cdfdf29 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se import attr from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) + async def store_state_deltas_for_batched( + self, + events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], + room_id: str, + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state deltas for a group of events and contexts created to be + batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: the events to generate and store a state groups for + and their associated contexts + room_id: the id of the room the events were created for + prev_group: the state group of the last event persisted before the batched events + were created + """ + + def insert_deltas_group_txn( + txn: LoggingTransaction, + events_and_context: List[Tuple[EventBase, UnpersistedEventContext]], + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state groups for the provided events and contexts. + + Requires that we have the state as a delta from the last persisted state group. + + Returns: + A list of state groups + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + num_state_groups = sum( + 1 for event, _ in events_and_context if event.is_state() + ) + + state_groups = self._state_group_seq_gen.get_next_mult_txn( + txn, num_state_groups + ) + + sg_before = prev_group + state_group_iter = iter(state_groups) + for event, context in events_and_context: + if not event.is_state(): + context.state_group_after_event = sg_before + context.state_group_before_event = sg_before + continue + + sg_after = next(state_group_iter) + context.state_group_after_event = sg_after + context.state_group_before_event = sg_before + context.state_delta_due_to_event = { + (event.type, event.state_key): event.event_id + } + sg_before = sg_after + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups", + keys=("id", "room_id", "event_id"), + values=[ + (context.state_group_after_event, room_id, event.event_id) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_group_edges", + keys=("state_group", "prev_state_group"), + values=[ + ( + context.state_group_after_event, + context.state_group_before_event, + ) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) + for event, context in events_and_context + if context.state_delta_due_to_event is not None + for key, state_id in context.state_delta_due_to_event.items() + ], + ) + return events_and_context + + return await self.db_pool.runInteraction( + "store_state_deltas_for_batched.insert_deltas_group", + insert_deltas_group_txn, + events_and_context, + prev_group, + ) + async def store_state_group( self, event_id: str, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 69d384442f..9691d66b48 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): return memberEvent, memberEventContext - def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: + def _create_duplicate_event( + self, txn_id: str + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. """ @@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_suitably_random" - event1, context = self._create_duplicate_event(txn_id) + event1, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event1)) ret_event1 = self.get_success( self.handler.handle_new_client_event( @@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(event1.event_id, ret_event1.event_id) - event2, context = self._create_duplicate_event(txn_id) + event2, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event2)) # We want to test that the deduplication at the persit event end works, # so we want to make sure we test with different events. @@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_event` directly also does the right # thing. - event3, context = self._create_duplicate_event(txn_id) + event3, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event3)) + self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( @@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_events` directly also does the right # thing. - event4, context = self._create_duplicate_event(txn_id) + event4, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event4)) self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( @@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_else_suitably_random" # Create two duplicate events to persist at the same time - event1, context1 = self._create_duplicate_event(txn_id) - event2, context2 = self._create_duplicate_event(txn_id) + event1, unpersisted_context1 = self._create_duplicate_event(txn_id) + context1 = self.get_success(unpersisted_context1.persist(event1)) + event2, unpersisted_context2 = self._create_duplicate_event(txn_id) + context2 = self.get_success(unpersisted_context2.persist(event2)) # Ensure their event IDs are different to start with self.assertNotEqual(event1.event_id, event2.event_id) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1db99b3c00..aff1ec4758 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Lower the permissions of the inviter. event_creation_handler = self.hs.get_event_creation_handler() requester = create_requester(inviter) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_event( requester, { @@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_creation_handler.handle_new_client_event( requester, events_and_context=[(event, context)] diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index dce6899e78..1458076a90 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Create a new message event, and try to evaluate it under the dodgy # power level event. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): prev_event_ids=[pl_event_id], ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise @@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): """Ensure that push rules are not calculated when disabled in the config""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Mock the method which calculates push rules -- we do this instead of @@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) - + context = self.get_success(unpersisted_context.persist(event)) # Execute the push rule machinery. self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Create & persist an event to use as the parent of the relation. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self.event_creation_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4dd763096d..a4900703c4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(30, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 73d11e7786..e39b63edac 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_prev_events_for_room(room_id) ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] @@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): assert state_ids1 is not None state1 = set(state_ids1.values()) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index e82c03f597..62aed6af0a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + def test_batched_state_group_storing(self) -> None: + creation_event = self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, "", {} + ) + state_to_event = self.get_success( + self.storage.state.get_state_groups( + self.room.to_string(), [creation_event.event_id] + ) + ) + current_state_group = list(state_to_event.keys())[0] + + # create some unpersisted events and event contexts to store against room + events_and_context = [] + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Name, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"name": "first rename of room"}, + }, + ) + + event1, unpersisted_context1 = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + events_and_context.append((event1, unpersisted_context1)) + + builder2 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "private"}, + }, + ) + + event2, unpersisted_context2 = self.get_success( + self.event_creation_handler.create_new_client_event(builder2) + ) + events_and_context.append((event2, unpersisted_context2)) + + builder3 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Message, + "sender": self.u_alice.to_string(), + "room_id": self.room.to_string(), + "content": {"body": "hello from event 3", "msgtype": "m.text"}, + }, + ) + + event3, unpersisted_context3 = self.get_success( + self.event_creation_handler.create_new_client_event(builder3) + ) + events_and_context.append((event3, unpersisted_context3)) + + builder4 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "public"}, + }, + ) + + event4, unpersisted_context4 = self.get_success( + self.event_creation_handler.create_new_client_event(builder4) + ) + events_and_context.append((event4, unpersisted_context4)) + + processed_events_and_context = self.get_success( + self.hs.get_datastores().state.store_state_deltas_for_batched( + events_and_context, self.room.to_string(), current_state_group + ) + ) + + # check that only state events are in state_groups, and all state events are in state_groups + res = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ) + + events = [] + for result in res: + self.assertNotIn(event3.event_id, result) + events.append(result.get("event_id")) + + for event, _ in processed_events_and_context: + if event.is_state(): + self.assertIn(event.event_id, events) + + # check that each unique state has state group in state_groups_state and that the + # type/state key is correct, and check that each state event's state group + # has an entry and prev event in state_group_edges + for event, context in processed_events_and_context: + if event.is_state(): + state = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ) + self.assertEqual(event.type, state[0].get("type")) + self.assertEqual(event.state_key, state[0].get("state_key")) + + groups = self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={"state_group": str(context.state_group_after_event)}, + retcols=("*",), + ) + ) + self.assertEqual( + context.state_group_before_event, groups[0].get("prev_state_group") + ) diff --git a/tests/unittest.py b/tests/unittest.py index b21e7f1221..f9160faa1d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase): event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creator.create_event( requester, { @@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase): prev_event_ids=prev_event_ids, ) ) - + context = self.get_success(unpersisted_context.persist(event)) if soft_failed: event.internal_metadata.soft_failed = True -- cgit 1.5.1 From 4fc8875876374ec8f97a3b3cc344a4e3abcf769f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Feb 2023 08:26:05 -0500 Subject: Refactor media modules. (#15146) * Removes the `v1` directory from `test.rest.media.v1`. * Moves the non-REST code from `synapse.rest.media.v1` to `synapse.media`. * Flatten the `v1` directory from `synapse.rest.media`, but leave compatiblity with 3rd party media repositories and spam checkers. --- changelog.d/15146.misc | 1 + synapse/_scripts/move_remote_media_to_new_store.py | 2 +- synapse/config/repository.py | 12 +- synapse/events/spamcheck.py | 4 +- synapse/media/_base.py | 479 ++++++++ synapse/media/filepath.py | 410 +++++++ synapse/media/media_repository.py | 1038 ++++++++++++++++ synapse/media/media_storage.py | 374 ++++++ synapse/media/oembed.py | 265 +++++ synapse/media/preview_html.py | 501 ++++++++ synapse/media/storage_provider.py | 181 +++ synapse/media/thumbnailer.py | 221 ++++ synapse/rest/media/config_resource.py | 41 + synapse/rest/media/download_resource.py | 75 ++ synapse/rest/media/media_repository_resource.py | 93 ++ synapse/rest/media/preview_url_resource.py | 869 ++++++++++++++ synapse/rest/media/thumbnail_resource.py | 554 +++++++++ synapse/rest/media/upload_resource.py | 108 ++ synapse/rest/media/v1/_base.py | 470 +------- synapse/rest/media/v1/config_resource.py | 41 - synapse/rest/media/v1/download_resource.py | 76 -- synapse/rest/media/v1/filepath.py | 410 ------- synapse/rest/media/v1/media_repository.py | 1112 ------------------ synapse/rest/media/v1/media_storage.py | 365 +----- synapse/rest/media/v1/oembed.py | 265 ----- synapse/rest/media/v1/preview_html.py | 501 -------- synapse/rest/media/v1/preview_url_resource.py | 871 -------------- synapse/rest/media/v1/storage_provider.py | 172 +-- synapse/rest/media/v1/thumbnail_resource.py | 555 --------- synapse/rest/media/v1/thumbnailer.py | 221 ---- synapse/rest/media/v1/upload_resource.py | 108 -- synapse/server.py | 6 +- tests/media/__init__.py | 13 + tests/media/test_base.py | 38 + tests/media/test_filepath.py | 595 ++++++++++ tests/media/test_html_preview.py | 542 +++++++++ tests/media/test_media_storage.py | 792 +++++++++++++ tests/media/test_oembed.py | 162 +++ tests/rest/admin/test_media.py | 2 +- tests/rest/admin/test_user.py | 2 +- tests/rest/media/test_url_preview.py | 1234 ++++++++++++++++++++ tests/rest/media/v1/__init__.py | 13 - tests/rest/media/v1/test_base.py | 38 - tests/rest/media/v1/test_filepath.py | 595 ---------- tests/rest/media/v1/test_html_preview.py | 542 --------- tests/rest/media/v1/test_media_storage.py | 792 ------------- tests/rest/media/v1/test_oembed.py | 162 --- tests/rest/media/v1/test_url_preview.py | 1234 -------------------- 48 files changed, 8612 insertions(+), 8545 deletions(-) create mode 100644 changelog.d/15146.misc create mode 100644 synapse/media/_base.py create mode 100644 synapse/media/filepath.py create mode 100644 synapse/media/media_repository.py create mode 100644 synapse/media/media_storage.py create mode 100644 synapse/media/oembed.py create mode 100644 synapse/media/preview_html.py create mode 100644 synapse/media/storage_provider.py create mode 100644 synapse/media/thumbnailer.py create mode 100644 synapse/rest/media/config_resource.py create mode 100644 synapse/rest/media/download_resource.py create mode 100644 synapse/rest/media/media_repository_resource.py create mode 100644 synapse/rest/media/preview_url_resource.py create mode 100644 synapse/rest/media/thumbnail_resource.py create mode 100644 synapse/rest/media/upload_resource.py delete mode 100644 synapse/rest/media/v1/config_resource.py delete mode 100644 synapse/rest/media/v1/download_resource.py delete mode 100644 synapse/rest/media/v1/filepath.py delete mode 100644 synapse/rest/media/v1/media_repository.py delete mode 100644 synapse/rest/media/v1/oembed.py delete mode 100644 synapse/rest/media/v1/preview_html.py delete mode 100644 synapse/rest/media/v1/preview_url_resource.py delete mode 100644 synapse/rest/media/v1/thumbnail_resource.py delete mode 100644 synapse/rest/media/v1/thumbnailer.py delete mode 100644 synapse/rest/media/v1/upload_resource.py create mode 100644 tests/media/__init__.py create mode 100644 tests/media/test_base.py create mode 100644 tests/media/test_filepath.py create mode 100644 tests/media/test_html_preview.py create mode 100644 tests/media/test_media_storage.py create mode 100644 tests/media/test_oembed.py create mode 100644 tests/rest/media/test_url_preview.py delete mode 100644 tests/rest/media/v1/__init__.py delete mode 100644 tests/rest/media/v1/test_base.py delete mode 100644 tests/rest/media/v1/test_filepath.py delete mode 100644 tests/rest/media/v1/test_html_preview.py delete mode 100644 tests/rest/media/v1/test_media_storage.py delete mode 100644 tests/rest/media/v1/test_oembed.py delete mode 100644 tests/rest/media/v1/test_url_preview.py (limited to 'tests/rest') diff --git a/changelog.d/15146.misc b/changelog.d/15146.misc new file mode 100644 index 0000000000..8de5f95239 --- /dev/null +++ b/changelog.d/15146.misc @@ -0,0 +1 @@ +Refactor the media modules. diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py index 819afaaca6..0dd36bee20 100755 --- a/synapse/_scripts/move_remote_media_to_new_store.py +++ b/synapse/_scripts/move_remote_media_to_new_store.py @@ -37,7 +37,7 @@ import os import shutil import sys -from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.media.filepath import MediaFilePaths logger = logging.getLogger() diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2da40c09f0..ecb3edbe3a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -178,11 +178,13 @@ class ContentRepositoryConfig(Config): for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to # expose FileStorageProviderBackend - if provider_config["module"] == "file_system": - provider_config["module"] = ( - "synapse.rest.media.v1.storage_provider" - ".FileStorageProviderBackend" - ) + if ( + provider_config["module"] == "file_system" + or provider_config["module"] == "synapse.rest.media.v1.storage_provider" + ): + provider_config[ + "module" + ] = "synapse.media.storage_provider.FileStorageProviderBackend" provider_class, parsed_config = load_module( provider_config, ("media_storage_providers", "" % i) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 623a2c71ea..765c15bb51 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -33,8 +33,8 @@ from typing_extensions import Literal import synapse from synapse.api.errors import Codes from synapse.logging.opentracing import trace -from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.media_storage import ReadableFileWrapper +from synapse.media._base import FileInfo +from synapse.media.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import JsonDict, RoomAlias, UserProfile from synapse.util.async_helpers import delay_cancellation, maybe_awaitable diff --git a/synapse/media/_base.py b/synapse/media/_base.py new file mode 100644 index 0000000000..ef8334ae25 --- /dev/null +++ b/synapse/media/_base.py @@ -0,0 +1,479 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import urllib +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type + +import attr + +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender +from twisted.web.server import Request + +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.http.server import finish_request, respond_with_json +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name + +logger = logging.getLogger(__name__) + +# list all text content types that will have the charset default to UTF-8 when +# none is given +TEXT_CONTENT_TYPES = [ + "text/css", + "text/csv", + "text/html", + "text/calendar", + "text/plain", + "text/javascript", + "application/json", + "application/ld+json", + "application/rtf", + "image/svg+xml", + "text/xml", +] + + +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ + try: + # The type on postpath seems incorrect in Twisted 21.2.0. + postpath: List[bytes] = request.postpath # type: ignore + assert postpath + + # This allows users to append e.g. /test.png to the URL. Useful for + # clients that parse the URL to see content type. + server_name_bytes, media_id_bytes = postpath[:2] + server_name = server_name_bytes.decode("utf-8") + media_id = media_id_bytes.decode("utf8") + + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + + file_name = None + if len(postpath) > 2: + try: + file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) + except UnicodeDecodeError: + pass + return server_name, media_id, file_name + except Exception: + raise SynapseError( + 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN + ) + + +def respond_404(request: SynapseRequest) -> None: + respond_with_json( + request, + 404, + cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), + send_cors=True, + ) + + +async def respond_with_file( + request: SynapseRequest, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: + logger.debug("Responding with %r", file_path) + + if os.path.isfile(file_path): + if file_size is None: + stat = os.stat(file_path) + file_size = stat.st_size + + add_file_headers(request, media_type, file_size, upload_name) + + with open(file_path, "rb") as f: + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + + finish_request(request) + else: + respond_404(request) + + +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. + """ + + def _quote(x: str) -> str: + return urllib.parse.quote(x.encode("utf-8")) + + # Default to a UTF-8 charset for text content types. + # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' + if media_type.lower() in TEXT_CONTENT_TYPES: + content_type = media_type + "; charset=UTF-8" + else: + content_type = media_type + + request.setHeader(b"Content-Type", content_type.encode("UTF-8")) + if upload_name: + # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. + # + # `filename` is defined to be a `value`, which is defined by RFC2616 + # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` + # is (essentially) a single US-ASCII word, and a `quoted-string` is a + # US-ASCII string surrounded by double-quotes, using backslash as an + # escape character. Note that %-encoding is *not* permitted. + # + # `filename*` is defined to be an `ext-value`, which is defined in + # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, + # where `value-chars` is essentially a %-encoded string in the given charset. + # + # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 + # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 + # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 + + # We avoid the quoted-string version of `filename`, because (a) synapse didn't + # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we + # may as well just do the filename* version. + if _can_encode_filename_as_token(upload_name): + disposition = "inline; filename=%s" % (upload_name,) + else: + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) + + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) + + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + + +# separators as defined in RFC2616. SP and HT are handled separately. +# see _can_encode_filename_as_token. +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} + + +def _can_encode_filename_as_token(x: str) -> bool: + for c in x: + # from RFC2616: + # + # token = 1* + # + # separators = "(" | ")" | "<" | ">" | "@" + # | "," | ";" | ":" | "\" | <"> + # | "/" | "[" | "]" | "?" | "=" + # | "{" | "}" | SP | HT + # + # CHAR = + # + # CTL = + # + if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: + return False + return True + + +async def respond_with_responder( + request: SynapseRequest, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: + """Responds to the request with given responder. If responder is None then + returns 404. + + Args: + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + # If we have a responder we *must* use it as a context manager. + with responder: + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + logger.debug("Responding to media request with responder %s", responder) + add_file_headers(request, media_type, file_size, upload_name) + try: + await responder.write_to_consumer(request) + except Exception as e: + # The majority of the time this will be due to the client having gone + # away. Unfortunately, Twisted simply throws a generic exception at us + # in that case. + logger.warning("Failed to write to consumer: %s %s", type(e), e) + + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() + + finish_request(request) + + +class Responder(ABC): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + + @abstractmethod + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: + """Stream response into consumer + + Args: + consumer: The consumer to stream into. + + Returns: + Resolves once the response has finished being written + """ + raise NotImplementedError() + + def __enter__(self) -> None: # noqa: B027 + pass + + def __exit__( # noqa: B027 + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThumbnailInfo: + """Details about a generated thumbnail.""" + + width: int + height: int + method: str + # Content type of thumbnail, e.g. image/png + type: str + # The size of the media file, in bytes. + length: Optional[int] = None + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FileInfo: + """Details about a requested/uploaded file.""" + + # The server name where the media originated from, or None if local. + server_name: Optional[str] + # The local ID of the file. For local files this is the same as the media_id + file_id: str + # If the file is for the url preview cache + url_cache: bool = False + # Whether the file is a thumbnail or not. + thumbnail: Optional[ThumbnailInfo] = None + + # The below properties exist to maintain compatibility with third-party modules. + @property + def thumbnail_width(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.width + + @property + def thumbnail_height(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.height + + @property + def thumbnail_method(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.method + + @property + def thumbnail_type(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.type + + @property + def thumbnail_length(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.length + + +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: + """ + Get the filename of the downloaded file by inspecting the + Content-Disposition HTTP header. + + Args: + headers: The HTTP request headers. + + Returns: + The filename, or None. + """ + content_disposition = headers.get(b"Content-Disposition", [b""]) + + # No header, bail out. + if not content_disposition[0]: + return None + + _, params = _parse_header(content_disposition[0]) + + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get(b"filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith(b"utf-8''"): + upload_name_utf8 = upload_name_utf8[7:] + # We have a filename*= section. This MUST be ASCII, and any UTF-8 + # bytes are %-quoted. + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get(b"filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii.decode("ascii") + + # This may be None here, indicating we did not find a matching name. + return upload_name + + +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: + """Parse a Content-type like header. + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + line: header to be parsed + + Returns: + The main content-type, followed by the parameter dictionary + """ + parts = _parseparam(b";" + line) + key = next(parts) + pdict = {} + for p in parts: + i = p.find(b"=") + if i >= 0: + name = p[:i].strip().lower() + value = p[i + 1 :].strip() + + # strip double-quotes + if len(value) >= 2 and value[0:1] == value[-1:] == b'"': + value = value[1:-1] + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') + pdict[name] = value + + return key, pdict + + +def _parseparam(s: bytes) -> Generator[bytes, None, None]: + """Generator which splits the input on ;, respecting double-quoted sequences + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + s: header to be parsed + + Returns: + The split input + """ + while s[:1] == b";": + s = s[1:] + + # look for the next ; + end = s.find(b";") + + # if there is an odd number of " marks between here and the next ;, skip to the + # next ; instead + while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: + end = s.find(b";", end + 1) + + if end < 0: + end = len(s) + f = s[:end] + yield f.strip() + s = s[end:] diff --git a/synapse/media/filepath.py b/synapse/media/filepath.py new file mode 100644 index 0000000000..1f6441c412 --- /dev/null +++ b/synapse/media/filepath.py @@ -0,0 +1,410 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import re +import string +from typing import Any, Callable, List, TypeVar, Union, cast + +NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") + + +F = TypeVar("F", bound=Callable[..., str]) + + +def _wrap_in_base_path(func: F) -> F: + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + + @functools.wraps(func) + def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str: + path = func(self, *args, **kwargs) + return os.path.join(self.base_path, path) + + return cast(F, _wrapped) + + +GetPathMethod = TypeVar( + "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] +) + + +def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: + """Wraps a path-returning method to check that the returned path(s) do not escape + the media store directory. + + The path-returning method may return either a single path, or a list of paths. + + The check is not expected to ever fail, unless `func` is missing a call to + `_validate_path_component`, or `_validate_path_component` is buggy. + + Args: + relative: A boolean indicating whether the wrapped method returns paths relative + to the media store directory. + + Returns: + A method which will wrap a path-returning method, adding a check to ensure that + the returned path(s) lie within the media store directory. The check will raise + a `ValueError` if it fails. + """ + + def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # Construct the path that will ultimately be used. + # We cannot guess whether `path` is relative to the media store + # directory, since the media store directory may itself be a relative + # path. + if relative: + path = os.path.join(self.base_path, path) + normalized_path = os.path.normpath(path) + + # Now that `normpath` has eliminated `../`s and `./`s from the path, + # `os.path.commonpath` can be used to check whether it lies within the + # media store directory. + if ( + os.path.commonpath([normalized_path, self.normalized_base_path]) + != self.normalized_base_path + ): + # The path resolves to outside the media store directory, + # or `self.base_path` is `.`, which is an unlikely configuration. + raise ValueError(f"Invalid media store path: {path!r}") + + # Note that `os.path.normpath`/`abspath` has a subtle caveat: + # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a + # different path if `a/b/c` is a symlink. That is, the check above is + # not perfect and may allow a certain restricted subset of untrustworthy + # paths through. Since the check above is secondary to the main + # `_validate_path_component` checks, it's less important for it to be + # perfect. + # + # As an alternative, `os.path.realpath` will resolve symlinks, but + # proves problematic if there are symlinks inside the media store. + # eg. if `url_store/` is symlinked to elsewhere, its canonical path + # won't match that of the main media store directory. + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + return _wrap_with_jail_check_inner + + +ALLOWED_CHARACTERS = set( + string.ascii_letters + + string.digits + + "_-" + + ".[]:" # Domain names, IPv6 addresses and ports in server names +) +FORBIDDEN_NAMES = { + "", + os.path.curdir, # "." for the current platform + os.path.pardir, # ".." for the current platform +} + + +def _validate_path_component(name: str) -> str: + """Checks that the given string can be safely used as a path component + + Args: + name: The path component to check. + + Returns: + The path component if valid. + + Raises: + ValueError: If `name` cannot be safely used as a path component. + """ + if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: + raise ValueError(f"Invalid path component: {name!r}") + + return name + + +class MediaFilePaths: + """Describes where files are stored on disk. + + Most of the functions have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ + + def __init__(self, primary_base_path: str): + self.base_path = primary_base_path + self.normalized_base_path = os.path.normpath(self.base_path) + + # Refuse to initialize if paths cannot be validated correctly for the current + # platform. + assert os.path.sep not in ALLOWED_CHARACTERS + assert os.path.altsep not in ALLOWED_CHARACTERS + # On Windows, paths have all sorts of weirdness which `_validate_path_component` + # does not consider. In any case, the remote media store can't work correctly + # for certain homeservers there, since ":"s aren't allowed in paths. + assert os.name == "posix" + + @_wrap_with_jail_check(relative=True) + def local_media_filepath_rel(self, media_id: str) -> str: + return os.path.join( + "local_content", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + @_wrap_with_jail_check(relative=True) + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + return os.path.join( + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), + ) + + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + @_wrap_with_jail_check(relative=False) + def local_media_thumbnail_dir(self, media_id: str) -> str: + """ + Retrieve the local store path of thumbnails of a given media_id + + Args: + media_id: The media ID to query. + Returns: + Path of local_thumbnails from media_id + """ + return os.path.join( + self.base_path, + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + @_wrap_with_jail_check(relative=True) + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: + return os.path.join( + "remote_content", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + ) + + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + @_wrap_with_jail_check(relative=True) + def remote_media_thumbnail_rel( + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + return os.path.join( + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), + ) + + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + + # Legacy path that was used to store thumbnails previously. + # Should be removed after some time, when most of the thumbnails are stored + # using the new path. + @_wrap_with_jail_check(relative=True) + def remote_media_thumbnail_rel_legacy( + self, server_name: str, file_id: str, width: int, height: int, content_type: str + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + return os.path.join( + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), + ) + + @_wrap_with_jail_check(relative=False) + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: + return os.path.join( + self.base_path, + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + ) + + @_wrap_with_jail_check(relative=True) + def url_cache_filepath_rel(self, media_id: str) -> str: + if NEW_FORMAT_ID_RE.match(media_id): + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + return os.path.join( + "url_cache", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) + else: + return os.path.join( + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + + @_wrap_with_jail_check(relative=False) + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: + "The dirs to try and remove if we delete the media_id file" + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[:10]) + ) + ] + else: + return [ + os.path.join( + self.base_path, + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[0:2]) + ), + ] + + @_wrap_with_jail_check(relative=True) + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + _validate_path_component(file_name), + ) + else: + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), + ) + + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + + @_wrap_with_jail_check(relative=True) + def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) + else: + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + url_cache_thumbnail_directory = _wrap_in_base_path( + url_cache_thumbnail_directory_rel + ) + + @_wrap_with_jail_check(relative=False) + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: + "The dirs to try and remove if we delete the media_id thumbnails" + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + ), + ] + else: + return [ + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + ), + ] diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py new file mode 100644 index 0000000000..b81e3c2b0c --- /dev/null +++ b/synapse/media/media_repository.py @@ -0,0 +1,1038 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import errno +import logging +import os +import shutil +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +from matrix_common.types.mxc_uri import MXCUri + +import twisted.internet.error +import twisted.web.http +from twisted.internet.defer import Deferred + +from synapse.api.errors import ( + FederationDeniedError, + HttpResponseException, + NotFoundError, + RequestSendFailed, + SynapseError, +) +from synapse.config.repository import ThumbnailRequirement +from synapse.http.site import SynapseRequest +from synapse.logging.context import defer_to_thread +from synapse.media._base import ( + FileInfo, + Responder, + ThumbnailInfo, + get_filename_from_headers, + respond_404, + respond_with_responder, +) +from synapse.media.filepath import MediaFilePaths +from synapse.media.media_storage import MediaStorage +from synapse.media.storage_provider import StorageProviderWrapper +from synapse.media.thumbnailer import Thumbnailer, ThumbnailError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.async_helpers import Linearizer +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# How often to run the background job to update the "recently accessed" +# attribute of local and remote media. +UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute +# How often to run the background job to check for local and remote media +# that should be purged according to the configured media retention settings. +MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour + + +class MediaRepository: + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.client = hs.get_federation_http_client() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.store = hs.get_datastores().main + self.max_upload_size = hs.config.media.max_upload_size + self.max_image_pixels = hs.config.media.max_image_pixels + + Thumbnailer.set_limits(self.max_image_pixels) + + self.primary_base_path: str = hs.config.media.media_store_path + self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) + + self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails + self.thumbnail_requirements = hs.config.media.thumbnail_requirements + + self.remote_media_linearizer = Linearizer(name="media_remote") + + self.recently_accessed_remotes: Set[Tuple[str, str]] = set() + self.recently_accessed_locals: Set[str] = set() + + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) + + # List of StorageProviders where we should search for media and + # potentially upload to. + storage_providers = [] + + for ( + clz, + provider_config, + wrapper_config, + ) in hs.config.media.media_storage_providers: + backend = clz(hs, provider_config) + provider = StorageProviderWrapper( + backend, + store_local=wrapper_config.store_local, + store_remote=wrapper_config.store_remote, + store_synchronous=wrapper_config.store_synchronous, + ) + storage_providers.append(provider) + + self.media_storage = MediaStorage( + self.hs, self.primary_base_path, self.filepaths, storage_providers + ) + + self.clock.looping_call( + self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS + ) + + # Media retention configuration options + self._media_retention_local_media_lifetime_ms = ( + hs.config.media.media_retention_local_media_lifetime_ms + ) + self._media_retention_remote_media_lifetime_ms = ( + hs.config.media.media_retention_remote_media_lifetime_ms + ) + + # Check whether local or remote media retention is configured + if ( + hs.config.media.media_retention_local_media_lifetime_ms is not None + or hs.config.media.media_retention_remote_media_lifetime_ms is not None + ): + # Run the background job to apply media retention rules routinely, + # with the duration between runs dictated by the homeserver config. + self.clock.looping_call( + self._start_apply_media_retention_rules, + MEDIA_RETENTION_CHECK_PERIOD_MS, + ) + + def _start_update_recently_accessed(self) -> Deferred: + return run_as_background_process( + "update_recently_accessed_media", self._update_recently_accessed + ) + + def _start_apply_media_retention_rules(self) -> Deferred: + return run_as_background_process( + "apply_media_retention_rules", self._apply_media_retention_rules + ) + + async def _update_recently_accessed(self) -> None: + remote_media = self.recently_accessed_remotes + self.recently_accessed_remotes = set() + + local_media = self.recently_accessed_locals + self.recently_accessed_locals = set() + + await self.store.update_cached_last_access_time( + local_media, remote_media, self.clock.time_msec() + ) + + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: + """Mark the given media as recently accessed. + + Args: + server_name: Origin server of media, or None if local + media_id: The media ID of the content + """ + if server_name: + self.recently_accessed_remotes.add((server_name, media_id)) + else: + self.recently_accessed_locals.add(media_id) + + async def create_content( + self, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: UserID, + ) -> MXCUri: + """Store uploaded content for a local user and return the mxc URL + + Args: + media_type: The content type of the file. + upload_name: The name of the file, if provided. + content: A file like object that is the content to store + content_length: The length of the content + auth_user: The user_id of the uploader + + Returns: + The mxc url of the stored content + """ + + media_id = random_string(24) + + file_info = FileInfo(server_name=None, file_id=media_id) + + fname = await self.media_storage.store_file(content, file_info) + + logger.info("Stored local media in file %r", fname) + + await self.store.store_local_media( + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + + await self._generate_thumbnails(None, media_id, media_id, media_type) + + return MXCUri(self.server_name, media_id) + + async def get_local_media( + self, request: SynapseRequest, media_id: str, name: Optional[str] + ) -> None: + """Responds to requests for local media, if exists, or returns 404. + + Args: + request: The incoming request. + media_id: The media ID of the content. (This is the same as + the file_id for local content.) + name: Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Resolves once a response has successfully been written to request + """ + media_info = await self.store.get_local_media(media_id) + if not media_info or media_info["quarantined_by"]: + respond_404(request) + return + + self.mark_recently_accessed(None, media_id) + + media_type = media_info["media_type"] + if not media_type: + media_type = "application/octet-stream" + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + url_cache = media_info["url_cache"] + + file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) + + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( + request, responder, media_type, media_length, upload_name + ) + + async def get_remote_media( + self, + request: SynapseRequest, + server_name: str, + media_id: str, + name: Optional[str], + ) -> None: + """Respond to requests for remote media. + + Args: + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Resolves once a response has successfully been written to request + """ + if ( + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + self.mark_recently_accessed(server_name, media_id) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + async with self.remote_media_linearizer.queue(key): + responder, media_info = await self._get_remote_media_impl( + server_name, media_id + ) + + # We deliberately stream the file outside the lock + if responder: + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + await respond_with_responder( + request, responder, media_type, media_length, upload_name + ) + else: + respond_404(request) + + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: + """Gets the media info associated with the remote file, downloading + if necessary. + + Args: + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + + Returns: + The media info of the file + """ + if ( + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + async with self.remote_media_linearizer.queue(key): + responder, media_info = await self._get_remote_media_impl( + server_name, media_id + ) + + # Ensure we actually use the responder so that it releases resources + if responder: + with responder: + pass + + return media_info + + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: + """Looks for media in local cache, if not there then attempt to + download from remote server. + + Args: + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the + remote server). + + Returns: + A tuple of responder and the media info of the file. + """ + media_info = await self.store.get_cached_remote_media(server_name, media_id) + + # file_id is the ID we use to track the file locally. If we've already + # seen the file then reuse the existing ID, otherwise generate a new + # one. + + # If we have an entry in the DB, try and look for it + if media_info: + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + raise NotFoundError() + + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" + + responder = await self.media_storage.fetch_media(file_info) + if responder: + return responder, media_info + + # Failed to find the file anywhere, lets download it. + + try: + media_info = await self._download_remote_file( + server_name, + media_id, + ) + except SynapseError: + raise + except Exception as e: + # An exception may be because we downloaded media in another + # process, so let's check if we magically have the media. + media_info = await self.store.get_cached_remote_media(server_name, media_id) + if not media_info: + raise e + + file_id = media_info["filesystem_id"] + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" + file_info = FileInfo(server_name, file_id) + + # We generate thumbnails even if another process downloaded the media + # as a) it's conceivable that the other download request dies before it + # generates thumbnails, but mainly b) we want to be sure the thumbnails + # have finished being generated before responding to the client, + # otherwise they'll request thumbnails and get a 404 if they're not + # ready yet. + await self._generate_thumbnails( + server_name, media_id, file_id, media_info["media_type"] + ) + + responder = await self.media_storage.fetch_media(file_info) + return responder, media_info + + async def _download_remote_file( + self, + server_name: str, + media_id: str, + ) -> dict: + """Attempt to download the remote file from the given server name, + using the given file_id as the local id. + + Args: + server_name: Originating server + media_id: The media ID of the content (as defined by the + remote server). This is different than the file_id, which is + locally generated. + file_id: Local file ID + + Returns: + The media info of the file. + """ + + file_id = random_string(24) + + file_info = FileInfo(server_name=server_name, file_id=file_id) + + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + request_path = "/".join( + ("/_matrix/media/r0/download", server_name, media_id) + ) + try: + length, headers = await self.client.get_file( + server_name, + request_path, + output_stream=f, + max_size=self.max_upload_size, + args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. + "allow_remote": "false" + }, + ) + except RequestSendFailed as e: + logger.warning( + "Request failed fetching remote media %s/%s: %r", + server_name, + media_id, + e, + ) + raise SynapseError(502, "Failed to fetch remote media") + + except HttpResponseException as e: + logger.warning( + "HTTP error fetching remote media %s/%s: %s", + server_name, + media_id, + e.response, + ) + if e.code == twisted.web.http.NOT_FOUND: + raise e.to_synapse_error() + raise SynapseError(502, "Failed to fetch remote media") + + except SynapseError: + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise + except NotRetryingDestination: + logger.warning("Not retrying destination %r", server_name) + raise SynapseError(502, "Failed to fetch remote media") + except Exception: + logger.exception( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise SynapseError(502, "Failed to fetch remote media") + + await finish() + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) + + logger.info("Stored remote media in file %r", fname) + + media_info = { + "media_type": media_type, + "media_length": length, + "upload_name": upload_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + } + + return media_info + + def _get_thumbnail_requirements( + self, media_type: str + ) -> Tuple[ThumbnailRequirement, ...]: + scpos = media_type.find(";") + if scpos > 0: + media_type = media_type[:scpos] + return self.thumbnail_requirements.get(media_type, ()) + + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, + ) + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = thumbnailer.transpose() + + if t_method == "crop": + return thumbnailer.crop(t_width, t_height, t_type) + elif t_method == "scale": + t_width, t_height = thumbnailer.aspect(t_width, t_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + return thumbnailer.scale(t_width, t_height, t_type) + + return None + + async def generate_local_exact_thumbnail( + self, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + url_cache: bool, + ) -> Optional[str]: + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(None, media_id, url_cache=url_cache) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None + + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=None, + file_id=media_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + output_path = await self.media_storage.store_file( + t_byte_source, file_info + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) + + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + return output_path + + # Could not generate thumbnail. + return None + + async def generate_remote_exact_thumbnail( + self, + server_name: str, + file_id: str, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[str]: + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None + + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + output_path = await self.media_storage.store_file( + t_byte_source, file_info + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) + + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + + return output_path + + # Could not generate thumbnail. + return None + + async def _generate_thumbnails( + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: + """Generate and store thumbnails for an image. + + Args: + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as + the file_id for local content) + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, + used exclusively by the url previewer + + Returns: + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. + """ + requirements = self._get_thumbnail_requirements(media_type) + if not requirements: + return None + + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", + media_id, + server_name, + media_type, + e, + ) + return None + + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, + ) + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose + ) + + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. + thumbnails: Dict[Tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + with self.media_storage.store_into_file(file_info) as ( + f, + fname, + finish, + ): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + return {"width": m_width, "height": m_height} + + async def _apply_media_retention_rules(self) -> None: + """ + Purge old local and remote media according to the media retention rules + defined in the homeserver config. + """ + # Purge remote media + if self._media_retention_remote_media_lifetime_ms is not None: + # Calculate a threshold timestamp derived from the configured lifetime. Any + # media that has not been accessed since this timestamp will be removed. + remote_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms + ) + + logger.info( + "Purging remote media last accessed before" + f" {remote_media_threshold_timestamp_ms}" + ) + + await self.delete_old_remote_media( + before_ts=remote_media_threshold_timestamp_ms + ) + + # And now do the same for local media + if self._media_retention_local_media_lifetime_ms is not None: + # This works the same as the remote media threshold + local_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_local_media_lifetime_ms + ) + + logger.info( + "Purging local media last accessed before" + f" {local_media_threshold_timestamp_ms}" + ) + + await self.delete_old_local_media( + before_ts=local_media_threshold_timestamp_ms, + keep_profiles=True, + delete_quarantined_media=False, + delete_protected_media=False, + ) + + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: + old_media = await self.store.get_remote_media_ids( + before_ts, include_quarantined_media=False + ) + + deleted = 0 + + for media in old_media: + origin = media["media_origin"] + media_id = media["media_id"] + file_id = media["filesystem_id"] + key = (origin, media_id) + + logger.info("Deleting: %r", key) + + # TODO: Should we delete from the backup store + + async with self.remote_media_linearizer.queue(key): + full_path = self.filepaths.remote_media_filepath(origin, file_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r", full_path) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( + origin, file_id + ) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(origin, media_id) + deleted += 1 + + return {"deleted": deleted} + + async def delete_local_media_ids( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete the given local or remote media ID from this server + + Args: + media_id: The media ID to delete. + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + return await self._remove_local_media_from_disk(media_ids) + + async def delete_old_local_media( + self, + before_ts: int, + size_gt: int = 0, + keep_profiles: bool = True, + delete_quarantined_media: bool = False, + delete_protected_media: bool = False, + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server by size and timestamp. Removes + media files, any thumbnails and cached URLs. + + Args: + before_ts: Unix timestamp in ms. + Files that were last used before this timestamp will be deleted. + size_gt: Size of the media in bytes. Files that are larger will be deleted. + keep_profiles: Switch to delete also files that are still used in image data + (e.g user profile, room avatar). If false these files will be deleted. + delete_quarantined_media: If True, media marked as quarantined will be deleted. + delete_protected_media: If True, media marked as protected will be deleted. + + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + old_media = await self.store.get_local_media_ids( + before_ts, + size_gt, + keep_profiles, + include_quarantined_media=delete_quarantined_media, + include_protected_media=delete_protected_media, + ) + return await self._remove_local_media_from_disk(old_media) + + async def _remove_local_media_from_disk( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server. Removes media files, + any thumbnails and cached URLs. + + Args: + media_ids: List of media_id to delete + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + removed_media = [] + for media_id in media_ids: + logger.info("Deleting media with ID '%s'", media_id) + full_path = self.filepaths.local_media_filepath(media_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r: %s", full_path, e) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(self.server_name, media_id) + + await self.store.delete_url_cache((media_id,)) + await self.store.delete_url_cache_media((media_id,)) + + removed_media.append(media_id) + + return removed_media, len(removed_media) diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py new file mode 100644 index 0000000000..a7e22a91e1 --- /dev/null +++ b/synapse/media/media_storage.py @@ -0,0 +1,374 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import os +import shutil +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + Awaitable, + BinaryIO, + Callable, + Generator, + Optional, + Sequence, + Tuple, + Type, +) + +import attr + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender + +import synapse +from synapse.api.errors import NotFoundError +from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock +from synapse.util.file_consumer import BackgroundFileConsumer + +from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.media.storage_provider import StorageProvider + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MediaStorage: + """Responsible for storing/fetching files from local sources. + + Args: + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. + """ + + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): + self.hs = hs + self.reactor = hs.get_reactor() + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() + + async def store_file(self, source: IO, file_info: FileInfo) -> str: + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info: Info about the file to store + + Returns: + the file path written to in the primary media store + """ + + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + await self.write_to_file(source, f) + await finish_cb() + + return fname + + async def write_to_file(self, source: IO, output: IO) -> None: + """Asynchronously write the `source` to `output`.""" + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + + @contextlib.contextmanager + def store_into_file( + self, file_info: FileInfo + ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns an awaitable. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. + + finish_cb must be called and waited on after the file has been + successfully been written to. Should not be called if there was an + error. + + Args: + file_info: Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... + await finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + os.makedirs(dirname, exist_ok=True) + + finished_called = [False] + + try: + with open(fname, "wb") as f: + + async def finish() -> None: + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + spam_check = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam_check != synapse.module_api.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + + yield f, fname, finish + except Exception as e: + try: + os.remove(fname) + except Exception: + pass + + raise e from None + + if not finished_called: + raise Exception("Finished callback not called") + + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info + + Returns: + Returns a Responder if the file was found, otherwise None. + """ + paths = [self._file_info_to_path(file_info)] + + # fallback for remote thumbnails with no method in the filename + if file_info.thumbnail and file_info.server_name: + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + ) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) + + for provider in self.storage_providers: + for path in paths: + res: Any = await provider.fetch(path, file_info) + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) + + return None + + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: + """Ensures that the given file is in the local cache. Attempts to + download it from storage providers if it isn't. + + Args: + file_info + + Returns: + Full path to local file + """ + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + return local_path + + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + + dirname = os.path.dirname(local_path) + os.makedirs(dirname, exist_ok=True) + + for provider in self.storage_providers: + res: Any = await provider.fetch(path, file_info) + if res: + with res: + consumer = BackgroundFileConsumer( + open(local_path, "wb"), self.reactor + ) + await res.write_to_consumer(consumer) + await consumer.wait() + return local_path + + raise NotFoundError() + + def _file_info_to_path(self, file_info: FileInfo) -> str: + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + """ + if file_info.url_cache: + if file_info.thumbnail: + return self.filepaths.url_cache_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) + + +def _write_file_synchronously(source: IO, dest: IO) -> None: + """Write `source` to the file like `dest` synchronously. Should be called + from a thread. + + Args: + source: A file like object that's to be written + dest: A file like object to be written to + """ + source.seek(0) # Ensure we read from the start of the file + shutil.copyfileobj(source, dest) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file: A file like object to be streamed ot the client, + is closed when finished streaming. + """ + + def __init__(self, open_file: IO): + self.open_file = open_file + + def write_to_consumer(self, consumer: IConsumer) -> Deferred: + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True, auto_attribs=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2**14 + + clock: Clock + path: str + + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: + """Reads the file in chunks and calls the callback with each chunk.""" + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py new file mode 100644 index 0000000000..c0eaf04be5 --- /dev/null +++ b/synapse/media/oembed.py @@ -0,0 +1,265 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import html +import logging +import urllib.parse +from typing import TYPE_CHECKING, List, Optional + +import attr + +from synapse.media.preview_html import parse_html_description +from synapse.types import JsonDict +from synapse.util import json_decoder + +if TYPE_CHECKING: + from lxml import etree + + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class OEmbedResult: + # The Open Graph result (converted from the oEmbed result). + open_graph_result: JsonDict + # The author_name of the oEmbed result + author_name: Optional[str] + # Number of milliseconds to cache the content, according to the oEmbed response. + # + # This will be None if no cache-age is provided in the oEmbed response (or + # if the oEmbed response cannot be turned into an Open Graph response). + cache_age: Optional[int] + + +class OEmbedProvider: + """ + A helper for accessing oEmbed content. + + It can be used to check if a URL should be accessed via oEmbed and for + requesting/parsing oEmbed content. + """ + + def __init__(self, hs: "HomeServer"): + self._oembed_patterns = {} + for oembed_endpoint in hs.config.oembed.oembed_patterns: + api_endpoint = oembed_endpoint.api_endpoint + + # Only JSON is supported at the moment. This could be declared in + # the formats field. Otherwise, if the endpoint ends in .xml assume + # it doesn't support JSON. + if ( + oembed_endpoint.formats is not None + and "json" not in oembed_endpoint.formats + ) or api_endpoint.endswith(".xml"): + logger.info( + "Ignoring oEmbed endpoint due to not supporting JSON: %s", + api_endpoint, + ) + continue + + # Iterate through each URL pattern and point it to the endpoint. + for pattern in oembed_endpoint.url_patterns: + self._oembed_patterns[pattern] = api_endpoint + + def get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Args: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in self._oembed_patterns.items(): + if url_pattern.fullmatch(url): + # TODO Specify max height / width. + + # Note that only the JSON format is supported, some endpoints want + # this in the URL, others want it as an argument. + endpoint = endpoint.replace("{format}", "json") + + args = {"url": url, "format": "json"} + query_str = urllib.parse.urlencode(args, True) + return f"{endpoint}?{query_str}" + + # No match. + return None + + def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]: + """ + Search an HTML document for oEmbed autodiscovery information. + + Args: + tree: The parsed HTML body. + + Returns: + The URL to use for oEmbed information, or None if no URL was found. + """ + # Search for link elements with the proper rel and type attributes. + for tag in tree.xpath( + "//link[@rel='alternate'][@type='application/json+oembed']" + ): + if "href" in tag.attrib: + return tag.attrib["href"] + + # Some providers (e.g. Flickr) use alternative instead of alternate. + for tag in tree.xpath( + "//link[@rel='alternative'][@type='application/json+oembed']" + ): + if "href" in tag.attrib: + return tag.attrib["href"] + + return None + + def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: + """ + Parse the oEmbed response into an Open Graph response. + + Args: + url: The URL which is being previewed (not the one which was + requested). + raw_body: The oEmbed response as JSON encoded as bytes. + + Returns: + json-encoded Open Graph data + """ + + try: + # oEmbed responses *must* be UTF-8 according to the spec. + oembed = json_decoder.decode(raw_body.decode("utf-8")) + except ValueError: + return OEmbedResult({}, None, None) + + # The version is a required string field, but not always provided, + # or sometimes provided as a float. Be lenient. + oembed_version = oembed.get("version", "1.0") + if oembed_version != "1.0" and oembed_version != 1: + return OEmbedResult({}, None, None) + + # Attempt to parse the cache age, if possible. + try: + cache_age = int(oembed.get("cache_age")) * 1000 + except (TypeError, ValueError): + # If the cache age cannot be parsed (e.g. wrong type or invalid + # string), ignore it. + cache_age = None + + # The oEmbed response converted to Open Graph. + open_graph_response: JsonDict = {"og:url": url} + + title = oembed.get("title") + if title and isinstance(title, str): + # A common WordPress plug-in seems to incorrectly escape entities + # in the oEmbed response. + open_graph_response["og:title"] = html.unescape(title) + + author_name = oembed.get("author_name") + if not isinstance(author_name, str): + author_name = None + + # Use the provider name and as the site. + provider_name = oembed.get("provider_name") + if provider_name and isinstance(provider_name, str): + open_graph_response["og:site_name"] = provider_name + + # If a thumbnail exists, use it. Note that dimensions will be calculated later. + thumbnail_url = oembed.get("thumbnail_url") + if thumbnail_url and isinstance(thumbnail_url, str): + open_graph_response["og:image"] = thumbnail_url + + # Process each type separately. + oembed_type = oembed.get("type") + if oembed_type == "rich": + html_str = oembed.get("html") + if isinstance(html_str, str): + calc_description_and_urls(open_graph_response, html_str) + + elif oembed_type == "photo": + # If this is a photo, use the full image, not the thumbnail. + url = oembed.get("url") + if url and isinstance(url, str): + open_graph_response["og:image"] = url + + elif oembed_type == "video": + open_graph_response["og:type"] = "video.other" + html_str = oembed.get("html") + if html_str and isinstance(html_str, str): + calc_description_and_urls(open_graph_response, oembed["html"]) + for size in ("width", "height"): + val = oembed.get(size) + if type(val) is int: + open_graph_response[f"og:video:{size}"] = val + + elif oembed_type == "link": + open_graph_response["og:type"] = "website" + + else: + logger.warning("Unknown oEmbed type: %s", oembed_type) + + return OEmbedResult(open_graph_response, author_name, cache_age) + + +def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: + results = [] + for tag in tree.xpath("//*/" + tag_name): + if "src" in tag.attrib: + results.append(tag.attrib["src"]) + return results + + +def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None: + """ + Calculate description for an HTML document. + + This uses lxml to convert the HTML document into plaintext. If errors + occur during processing of the document, an empty response is returned. + + Args: + open_graph_response: The current Open Graph summary. This is updated with additional fields. + html_body: The HTML document, as bytes. + + Returns: + The summary + """ + # If there's no body, nothing useful is going to be found. + if not html_body: + return + + from lxml import etree + + # Create an HTML parser. If this fails, log and return no metadata. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(html_body, parser) + + # The data was successfully parsed, but no tree was found. + if tree is None: + return + + # Attempt to find interesting URLs (images, videos, embeds). + if "og:image" not in open_graph_response: + image_urls = _fetch_urls(tree, "img") + if image_urls: + open_graph_response["og:image"] = image_urls[0] + + video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed") + if video_urls: + open_graph_response["og:video"] = video_urls[0] + + description = parse_html_description(tree) + if description: + open_graph_response["og:description"] = description diff --git a/synapse/media/preview_html.py b/synapse/media/preview_html.py new file mode 100644 index 0000000000..516d0434f0 --- /dev/null +++ b/synapse/media/preview_html.py @@ -0,0 +1,501 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import logging +import re +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Union, +) + +if TYPE_CHECKING: + from lxml import etree + +logger = logging.getLogger(__name__) + +_charset_match = re.compile( + rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I +) +_xml_encoding_match = re.compile( + rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I +) +_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) + +# Certain elements aren't meant for display. +ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} + + +def _normalise_encoding(encoding: str) -> Optional[str]: + """Use the Python codec's name as the normalised entry.""" + try: + return codecs.lookup(encoding).name + except LookupError: + return None + + +def _get_html_media_encodings( + body: bytes, content_type: Optional[str] +) -> Iterable[str]: + """ + Get potential encoding of the body based on the (presumably) HTML body or the content-type header. + + The precedence used for finding a character encoding is: + + 1. tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to utf-8. + 5. Fallback to windows-1252. + + This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # There's no point in returning an encoding more than once. + attempted_encodings: Set[str] = set() + + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Check if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding: + attempted_encodings.add(encoding) + yield encoding + + # TODO Support + + # Check if it has an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Check the HTTP Content-Type header for a character set. + if content_type: + content_match = _content_type_match.match(content_type) + if content_match: + encoding = _normalise_encoding(content_match.group(1)) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Finally, fallback to UTF-8, then windows-1252. + for fallback in ("utf-8", "cp1252"): + if fallback not in attempted_encodings: + yield fallback + + +def decode_body( + body: bytes, uri: str, content_type: Optional[str] = None +) -> Optional["etree.Element"]: + """ + This uses lxml to parse the HTML document. + + Args: + body: The HTML document, as bytes. + uri: The URI used to download the body. + content_type: The Content-Type header. + + Returns: + The parsed HTML body, or None if an error occurred during processed. + """ + # If there's no body, nothing useful is going to be found. + if not body: + return None + + # The idea here is that multiple encodings are tried until one works. + # Unfortunately the result is never used and then LXML will decode the string + # again with the found encoding. + for encoding in _get_html_media_encodings(body, content_type): + try: + body.decode(encoding) + except Exception: + pass + else: + break + else: + logger.warning("Unable to decode HTML body for %s", uri) + return None + + from lxml import etree + + # Create an HTML parser. + parser = etree.HTMLParser(recover=True, encoding=encoding) + + # Attempt to parse the body. Returns None if the body was successfully + # parsed, but no tree was found. + return etree.fromstring(body, parser) + + +def _get_meta_tags( + tree: "etree.Element", + property: str, + prefix: str, + property_mapper: Optional[Callable[[str], Optional[str]]] = None, +) -> Dict[str, Optional[str]]: + """ + Search for meta tags prefixed with a particular string. + + Args: + tree: The parsed HTML document. + property: The name of the property which contains the tag name, e.g. + "property" for Open Graph. + prefix: The prefix on the property to search for, e.g. "og" for Open Graph. + property_mapper: An optional callable to map the property to the Open Graph + form. Can return None for a key to ignore that key. + + Returns: + A map of tag name to value. + """ + results: Dict[str, Optional[str]] = {} + for tag in tree.xpath( + f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]" + ): + # if we've got more than 50 tags, someone is taking the piss + if len(results) >= 50: + logger.warning( + "Skipping parsing of Open Graph for page with too many '%s:' tags", + prefix, + ) + return {} + + key = tag.attrib[property] + if property_mapper: + key = property_mapper(key) + # None is a special value used to ignore a value. + if key is None: + continue + + results[key] = tag.attrib["content"] + + return results + + +def _map_twitter_to_open_graph(key: str) -> Optional[str]: + """ + Map a Twitter card property to the analogous Open Graph property. + + Args: + key: The Twitter card property (starts with "twitter:"). + + Returns: + The Open Graph property (starts with "og:") or None to have this property + be ignored. + """ + # Twitter card properties with no analogous Open Graph property. + if key == "twitter:card" or key == "twitter:creator": + return None + if key == "twitter:site": + return "og:site_name" + # Otherwise, swap twitter to og. + return "og" + key[7:] + + +def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: + """ + Parse the HTML document into an Open Graph response. + + This uses lxml to search the HTML document for Open Graph data (or + synthesizes it from the document). + + Args: + tree: The parsed HTML document. + + Returns: + The Open Graph response as a dictionary. + """ + + # Search for Open Graph (og:) meta tags, e.g.: + # + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og = _get_meta_tags(tree, "property", "og") + + # TODO: Search for properties specific to the different Open Graph types, + # such as article: meta tags, e.g.: + # + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + # Search for Twitter Card (twitter:) meta tags, e.g.: + # + # "twitter:site" : "@matrixdotorg" + # "twitter:creator" : "@matrixdotorg" + # + # Twitter cards tags also duplicate Open Graph tags. + # + # See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started + twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph) + # Merge the Twitter values with the Open Graph values, but do not overwrite + # information from Open Graph tags. + for key, value in twitter.items(): + if key not in og: + og[key] = value + + if "og:title" not in og: + # Attempt to find a title from the title tag, or the biggest header on the page. + title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") + if title: + og["og:title"] = title[0].strip() + else: + og["og:title"] = None + + if "og:image" not in og: + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" + ) + # If a meta image is found, use it. + if meta_image: + og["og:image"] = meta_image[0] + else: + # Try to find images which are larger than 10px by 10px. + # + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) + # If no images were found, try to find *any* images. + if not images: + images = tree.xpath("//img[@src][1]") + if images: + og["og:image"] = images[0].attrib["src"] + + # Finally, fallback to the favicon if nothing else. + else: + favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") + if favicons: + og["og:image"] = favicons[0] + + if "og:description" not in og: + # Check the first meta description tag for content. + meta_description = tree.xpath( + "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" + ) + # If a meta description is found with content, use it. + if meta_description: + og["og:description"] = meta_description[0] + else: + og["og:description"] = parse_html_description(tree) + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) + og["og:description"] = summarize_paragraphs([og["og:description"]]) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def parse_html_description(tree: "etree.Element") -> Optional[str]: + """ + Calculate a text description based on an HTML document. + + Grabs any text nodes which are inside the tag, unless they are within + an HTML5 semantic markup tag (
,