From 32a59a6495f8d463f82ae52283159359a9961c25 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 23 Nov 2023 12:35:37 +0000 Subject: Keep track of `user_ips` and `monthly_active_users` when delegating auth (#16672) * Describe `insert_client_ip` * Pull out client_ips and MAU tracking to BaseAuth * Define HAS_AUTHLIB once in tests sick of copypasting * Track ips and token usage when delegating auth * Test that we track MAU and user_ips * Don't track `__oidc_admin` --- tests/config/test_oauth_delegation.py | 10 +-- tests/handlers/test_oauth_delegation.py | 129 ++++++++++++++++++++++++++++---- tests/rest/admin/test_jwks.py | 8 +- tests/rest/client/test_keys.py | 8 +- tests/rest/test_well_known.py | 8 +- tests/utils.py | 7 ++ 6 files changed, 126 insertions(+), 44 deletions(-) (limited to 'tests') diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py index 5c91031746..b1a9db0210 100644 --- a/tests/config/test_oauth_delegation.py +++ b/tests/config/test_oauth_delegation.py @@ -22,15 +22,7 @@ from synapse.types import JsonDict from tests.server import get_clock, setup_test_homeserver from tests.unittest import TestCase, skip_unless -from tests.utils import default_config - -try: - import authlib # noqa: F401 - - HAS_AUTHLIB = True -except ImportError: - HAS_AUTHLIB = False - +from tests.utils import HAS_AUTHLIB, default_config # These are a few constants that are used as config parameters in the tests. SERVER_NAME = "test" diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index a72ecfdc97..a2b8e3d562 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -13,7 +13,8 @@ # limitations under the License. from http import HTTPStatus -from typing import Any, Dict, Union +from io import BytesIO +from typing import Any, Dict, Optional, Union from unittest.mock import ANY, AsyncMock, Mock from urllib.parse import parse_qs @@ -25,6 +26,8 @@ from signedjson.key import ( from signedjson.sign import sign_json from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.api.errors import ( AuthError, @@ -33,23 +36,17 @@ from synapse.api.errors import ( OAuthInsufficientScopeError, SynapseError, ) +from synapse.http.site import SynapseRequest from synapse.rest import admin from synapse.rest.client import account, devices, keys, login, logout, register from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock +from tests.server import FakeChannel from tests.test_utils import FakeResponse, get_awaitable_result -from tests.unittest import HomeserverTestCase, skip_unless -from tests.utils import mock_getRawHeaders - -try: - import authlib # noqa: F401 - - HAS_AUTHLIB = True -except ImportError: - HAS_AUTHLIB = False - +from tests.unittest import HomeserverTestCase, override_config, skip_unless +from tests.utils import HAS_AUTHLIB, checked_cast, mock_getRawHeaders # These are a few constants that are used as config parameters in the tests. 
SERVER_NAME = "test" @@ -75,6 +72,7 @@ MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE SUBJECT = "abc-def-ghi" USERNAME = "test-user" USER_ID = "@" + USERNAME + ":" + SERVER_NAME +OIDC_ADMIN_USERID = f"@__oidc_admin:{SERVER_NAME}" async def get_json(url: str) -> JsonDict: @@ -134,7 +132,10 @@ class MSC3861OAuthDelegation(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) - self.auth = hs.get_auth() + # Import this here so that we've checked that authlib is available. + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + self.auth = checked_cast(MSC3861DelegatedAuth, hs.get_auth()) return hs @@ -675,7 +676,8 @@ class MSC3861OAuthDelegation(HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual( - requester.user.to_string(), "@%s:%s" % ("__oidc_admin", SERVER_NAME) + requester.user.to_string(), + OIDC_ADMIN_USERID, ) self.assertEqual(requester.is_guest, False) self.assertEqual(requester.device_id, None) @@ -685,3 +687,102 @@ class MSC3861OAuthDelegation(HomeserverTestCase): # There should be no call to the introspection endpoint self.http_client.request.assert_not_called() + + @override_config({"mau_stats_only": True}) + def test_request_tracking(self) -> None: + """Using an access token should update the client_ips and MAU tables.""" + # To start, there are no MAU users. + store = self.hs.get_datastores().main + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 0) + + known_token = "token-token-GOOD-:)" + + async def mock_http_client_request( + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """Mocked auth provider response.""" + assert method == "POST" + token = parse_qs(data)[b"token"][0].decode("utf-8") + if token == known_token: + return FakeResponse.json( + code=200, + payload={ + "active": True, + "scope": MATRIX_USER_SCOPE, + "sub": SUBJECT, + "username": USERNAME, + }, + ) + + return FakeResponse.json(code=200, payload={"active": False}) + + self.http_client.request = mock_http_client_request + + EXAMPLE_IPV4_ADDR = "123.123.123.123" + EXAMPLE_USER_AGENT = "httprettygood" + + # First test a known access token + channel = FakeChannel(self.site, self.reactor) + # type-ignore: FakeChannel is a mock of an HTTPChannel, not a proper HTTPChannel + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] + req.client.host = EXAMPLE_IPV4_ADDR + req.requestHeaders.addRawHeader("Authorization", f"Bearer {known_token}") + req.requestHeaders.addRawHeader("User-Agent", EXAMPLE_USER_AGENT) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/client/v3/account/whoami", + b"1.1", + ) + channel.await_result() + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(channel.json_body["user_id"], USER_ID, channel.json_body) + + # Expect to see one MAU entry, from the first request + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 1) + + conn_infos = self.get_success( + store.get_user_ip_and_agents(UserID.from_string(USER_ID)) + ) + self.assertEqual(len(conn_infos), 1, conn_infos) + conn_info = conn_infos[0] + self.assertEqual(conn_info["access_token"], known_token) + self.assertEqual(conn_info["ip"], EXAMPLE_IPV4_ADDR) + self.assertEqual(conn_info["user_agent"], EXAMPLE_USER_AGENT) + + # Now test MAS making a request using the special 
__oidc_admin token + MAS_IPV4_ADDR = "127.0.0.1" + MAS_USER_AGENT = "masmasmas" + + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] + req.client.host = MAS_IPV4_ADDR + req.requestHeaders.addRawHeader( + "Authorization", f"Bearer {self.auth._admin_token}" + ) + req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/client/v3/account/whoami", + b"1.1", + ) + channel.await_result() + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual( + channel.json_body["user_id"], OIDC_ADMIN_USERID, channel.json_body + ) + + # Still expect to see one MAU entry, from the first request + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 1) + + conn_infos = self.get_success( + store.get_user_ip_and_agents(UserID.from_string(OIDC_ADMIN_USERID)) + ) + self.assertEqual(conn_infos, []) diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py index a9a6191c73..842e92c3d0 100644 --- a/tests/rest/admin/test_jwks.py +++ b/tests/rest/admin/test_jwks.py @@ -19,13 +19,7 @@ from twisted.web.resource import Resource from synapse.rest.synapse.client import build_synapse_client_resource_tree from tests.unittest import HomeserverTestCase, override_config, skip_unless - -try: - import authlib # noqa: F401 - - HAS_AUTHLIB = True -except ImportError: - HAS_AUTHLIB = False +from tests.utils import HAS_AUTHLIB @skip_unless(HAS_AUTHLIB, "requires authlib") diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py index 9f81a695fa..a6023dff7a 100644 --- a/tests/rest/client/test_keys.py +++ b/tests/rest/client/test_keys.py @@ -30,13 +30,7 @@ from synapse.types import JsonDict, Requester, create_requester from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.unittest import override_config - -try: - import authlib # noqa: F401 - - HAS_AUTHLIB = True -except ImportError: - HAS_AUTHLIB = False +from tests.utils import HAS_AUTHLIB class KeyQueryTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 377243a170..7931a70abb 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -16,13 +16,7 @@ from twisted.web.resource import Resource from synapse.rest.well_known import well_known_resource from tests import unittest - -try: - import authlib # noqa: F401 - - HAS_AUTHLIB = True -except ImportError: - HAS_AUTHLIB = False +from tests.utils import HAS_AUTHLIB class WellKnownTests(unittest.HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index a0c87ad628..e0066fe15a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,6 +30,13 @@ from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +try: + import authlib # noqa: F401 + + HAS_AUTHLIB = True +except ImportError: + HAS_AUTHLIB = False + # set this to True to run the tests against postgres instead of sqlite. 
#
# When running under postgres, we first create a base database with the name
-- cgit 1.5.1 


From 0619c2bbd266a1e643bfe65a675afba8871aeb95 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 27 Nov 2023 01:29:46 +0000 Subject: Move media retention tests out of rest tests (#16684)

* Move media retention tests out of rest tests

AFAICS this doesn't make any HTTP requests and so it ought not to belong in `tests.rest`.

* Changelog
---
 changelog.d/16684.misc                   |   1 +
 tests/media/test_media_retention.py      | 294 +++++++++++++++++++++++++++++++
 tests/rest/media/test_media_retention.py | 294 -------------------------------
 3 files changed, 295 insertions(+), 294 deletions(-)
 create mode 100644 changelog.d/16684.misc
 create mode 100644 tests/media/test_media_retention.py
 delete mode 100644 tests/rest/media/test_media_retention.py

(limited to 'tests')

diff --git a/changelog.d/16684.misc b/changelog.d/16684.misc
new file mode 100644
index 0000000000..6fb55c08a5
--- /dev/null
+++ b/changelog.d/16684.misc
@@ -0,0 +1 @@
+Reorganise test files.

diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py
new file mode 100644
index 0000000000..27a663a23b
--- /dev/null
+++ b/tests/media/test_media_retention.py
@@ -0,0 +1,294 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+from typing import Iterable, Optional
+
+from matrix_common.types.mxc_uri import MXCUri
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, register, room
+from synapse.server import HomeServer
+from synapse.types import UserID
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+from tests.utils import MockClock
+
+
+class MediaRetentionTestCase(unittest.HomeserverTestCase):
+    ONE_DAY_IN_MS = 24 * 60 * 60 * 1000
+    THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS
+
+    servlets = [
+        room.register_servlets,
+        login.register_servlets,
+        register.register_servlets,
+        admin.register_servlets_for_client_rest_resource,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        # We need to be able to test advancing time in the homeserver, so we
+        # replace the test homeserver's default clock with a MockClock, which
+        # supports advancing time.
+        return self.setup_test_homeserver(clock=MockClock())
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.remote_server_name = "remote.homeserver"
+        self.store = hs.get_datastores().main
+
+        # Create a user to upload media with
+        test_user_id = self.register_user("alice", "password")
+
+        # Inject media (recently accessed, old access, never accessed, old access
+        # quarantined media) into both the local store and the remote cache, plus
+        # one additional local media that is marked as protected from quarantine.
+ media_repository = hs.get_media_repository() + test_media_content = b"example string" + + def _create_media_and_set_attributes( + last_accessed_ms: Optional[int], + is_quarantined: Optional[bool] = False, + is_protected: Optional[bool] = False, + ) -> MXCUri: + # "Upload" some media to the local media store + mxc_uri: MXCUri = self.get_success( + media_repository.create_content( + media_type="text/plain", + upload_name=None, + content=io.BytesIO(test_media_content), + content_length=len(test_media_content), + auth_user=UserID.from_string(test_user_id), + ) + ) + + # Set the last recently accessed time for this media + if last_accessed_ms is not None: + self.get_success( + self.store.update_cached_last_access_time( + local_media=(mxc_uri.media_id,), + remote_media=(), + time_ms=last_accessed_ms, + ) + ) + + if is_quarantined: + # Mark this media as quarantined + self.get_success( + self.store.quarantine_media_by_id( + server_name=self.hs.config.server.server_name, + media_id=mxc_uri.media_id, + quarantined_by="@theadmin:test", + ) + ) + + if is_protected: + # Mark this media as protected from quarantine + self.get_success( + self.store.mark_local_media_as_safe( + media_id=mxc_uri.media_id, + safe=True, + ) + ) + + return mxc_uri + + def _cache_remote_media_and_set_attributes( + media_id: str, + last_accessed_ms: Optional[int], + is_quarantined: Optional[bool] = False, + ) -> MXCUri: + # Pretend to cache some remote media + self.get_success( + self.store.store_cached_remote_media( + origin=self.remote_server_name, + media_id=media_id, + media_type="text/plain", + media_length=1, + time_now_ms=clock.time_msec(), + upload_name="testfile.txt", + filesystem_id="abcdefg12345", + ) + ) + + # Set the last recently accessed time for this media + if last_accessed_ms is not None: + self.get_success( + hs.get_datastores().main.update_cached_last_access_time( + local_media=(), + remote_media=((self.remote_server_name, media_id),), + time_ms=last_accessed_ms, + ) + ) + + if is_quarantined: + # Mark this media as quarantined + self.get_success( + self.store.quarantine_media_by_id( + server_name=self.remote_server_name, + media_id=media_id, + quarantined_by="@theadmin:test", + ) + ) + + return MXCUri(self.remote_server_name, media_id) + + # Start with the local media store + self.local_recently_accessed_media = _create_media_and_set_attributes( + last_accessed_ms=self.THIRTY_DAYS_IN_MS, + ) + self.local_not_recently_accessed_media = _create_media_and_set_attributes( + last_accessed_ms=self.ONE_DAY_IN_MS, + ) + self.local_not_recently_accessed_quarantined_media = ( + _create_media_and_set_attributes( + last_accessed_ms=self.ONE_DAY_IN_MS, + is_quarantined=True, + ) + ) + self.local_not_recently_accessed_protected_media = ( + _create_media_and_set_attributes( + last_accessed_ms=self.ONE_DAY_IN_MS, + is_protected=True, + ) + ) + self.local_never_accessed_media = _create_media_and_set_attributes( + last_accessed_ms=None, + ) + + # And now the remote media store + self.remote_recently_accessed_media = _cache_remote_media_and_set_attributes( + media_id="a", + last_accessed_ms=self.THIRTY_DAYS_IN_MS, + ) + self.remote_not_recently_accessed_media = ( + _cache_remote_media_and_set_attributes( + media_id="b", + last_accessed_ms=self.ONE_DAY_IN_MS, + ) + ) + self.remote_not_recently_accessed_quarantined_media = ( + _cache_remote_media_and_set_attributes( + media_id="c", + last_accessed_ms=self.ONE_DAY_IN_MS, + is_quarantined=True, + ) + ) + # Remote media will always have a "last accessed" attribute, as it would 
not
+    # be fetched from the remote homeserver unless instigated by a user.
+
+    @override_config(
+        {
+            "media_retention": {
+                # Enable retention for local media
+                "local_media_lifetime": "30d"
+                # Cached remote media should not be purged
+            }
+        }
+    )
+    def test_local_media_retention(self) -> None:
+        """
+        Tests that local media that have not been accessed recently are purged,
+        while cached remote media is unaffected.
+        """
+        # Advance 31 days (in seconds)
+        self.reactor.advance(31 * 24 * 60 * 60)
+
+        # Check that media has been correctly purged.
+        # Local media accessed <30 days ago should still exist.
+        # Remote media should be unaffected.
+        self._assert_if_mxc_uris_purged(
+            purged=[
+                self.local_not_recently_accessed_media,
+                self.local_never_accessed_media,
+            ],
+            not_purged=[
+                self.local_recently_accessed_media,
+                self.local_not_recently_accessed_quarantined_media,
+                self.local_not_recently_accessed_protected_media,
+                self.remote_recently_accessed_media,
+                self.remote_not_recently_accessed_media,
+                self.remote_not_recently_accessed_quarantined_media,
+            ],
+        )
+
+    @override_config(
+        {
+            "media_retention": {
+                # Enable retention for cached remote media
+                "remote_media_lifetime": "30d"
+                # Local media should not be purged
+            }
+        }
+    )
+    def test_remote_media_cache_retention(self) -> None:
+        """
+        Tests that entries from the remote media cache that have not been accessed
+        recently are purged, while local media is unaffected.
+        """
+        # Advance 31 days (in seconds)
+        self.reactor.advance(31 * 24 * 60 * 60)
+
+        # Check that media has been correctly purged.
+        # Local media should be unaffected.
+        # Remote media accessed <30 days ago should still exist.
+        self._assert_if_mxc_uris_purged(
+            purged=[
+                self.remote_not_recently_accessed_media,
+            ],
+            not_purged=[
+                self.remote_recently_accessed_media,
+                self.local_recently_accessed_media,
+                self.local_not_recently_accessed_media,
+                self.local_not_recently_accessed_quarantined_media,
+                self.local_not_recently_accessed_protected_media,
+                self.remote_not_recently_accessed_quarantined_media,
+                self.local_never_accessed_media,
+            ],
+        )
+
+    def _assert_if_mxc_uris_purged(
+        self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri]
+    ) -> None:
+        def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
+            """Given an MXC URI, assert whether it has been purged or not."""
+            if mxc_uri.server_name == self.hs.config.server.server_name:
+                found_media = bool(
+                    self.get_success(self.store.get_local_media(mxc_uri.media_id))
+                )
+            else:
+                found_media = bool(
+                    self.get_success(
+                        self.store.get_cached_remote_media(
+                            mxc_uri.server_name, mxc_uri.media_id
+                        )
+                    )
+                )
+
+            if expect_purged:
+                self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
+            else:
+                self.assertTrue(
+                    found_media,
+                    msg=f"{mxc_uri} unexpectedly purged",
+                )
+
+        # Assert that the given MXC URIs have either been correctly purged or not.
+        for mxc_uri in purged:
+            _assert_mxc_uri_purge_state(mxc_uri, expect_purged=True)
+        for mxc_uri in not_purged:
+            _assert_mxc_uri_purge_state(mxc_uri, expect_purged=False)

diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
deleted file mode 100644
index 27a663a23b..0000000000
--- a/tests/rest/media/test_media_retention.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -from typing import Iterable, Optional - -from matrix_common.types.mxc_uri import MXCUri - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.rest import admin -from synapse.rest.client import login, register, room -from synapse.server import HomeServer -from synapse.types import UserID -from synapse.util import Clock - -from tests import unittest -from tests.unittest import override_config -from tests.utils import MockClock - - -class MediaRetentionTestCase(unittest.HomeserverTestCase): - ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 - THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS - - servlets = [ - room.register_servlets, - login.register_servlets, - register.register_servlets, - admin.register_servlets_for_client_rest_resource, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - # We need to be able to test advancing time in the homeserver, so we - # replace the test homeserver's default clock with a MockClock, which - # supports advancing time. - return self.setup_test_homeserver(clock=MockClock()) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.remote_server_name = "remote.homeserver" - self.store = hs.get_datastores().main - - # Create a user to upload media with - test_user_id = self.register_user("alice", "password") - - # Inject media (recently accessed, old access, never accessed, old access - # quarantined media) into both the local store and the remote cache, plus - # one additional local media that is marked as protected from quarantine. 
- media_repository = hs.get_media_repository() - test_media_content = b"example string" - - def _create_media_and_set_attributes( - last_accessed_ms: Optional[int], - is_quarantined: Optional[bool] = False, - is_protected: Optional[bool] = False, - ) -> MXCUri: - # "Upload" some media to the local media store - mxc_uri: MXCUri = self.get_success( - media_repository.create_content( - media_type="text/plain", - upload_name=None, - content=io.BytesIO(test_media_content), - content_length=len(test_media_content), - auth_user=UserID.from_string(test_user_id), - ) - ) - - # Set the last recently accessed time for this media - if last_accessed_ms is not None: - self.get_success( - self.store.update_cached_last_access_time( - local_media=(mxc_uri.media_id,), - remote_media=(), - time_ms=last_accessed_ms, - ) - ) - - if is_quarantined: - # Mark this media as quarantined - self.get_success( - self.store.quarantine_media_by_id( - server_name=self.hs.config.server.server_name, - media_id=mxc_uri.media_id, - quarantined_by="@theadmin:test", - ) - ) - - if is_protected: - # Mark this media as protected from quarantine - self.get_success( - self.store.mark_local_media_as_safe( - media_id=mxc_uri.media_id, - safe=True, - ) - ) - - return mxc_uri - - def _cache_remote_media_and_set_attributes( - media_id: str, - last_accessed_ms: Optional[int], - is_quarantined: Optional[bool] = False, - ) -> MXCUri: - # Pretend to cache some remote media - self.get_success( - self.store.store_cached_remote_media( - origin=self.remote_server_name, - media_id=media_id, - media_type="text/plain", - media_length=1, - time_now_ms=clock.time_msec(), - upload_name="testfile.txt", - filesystem_id="abcdefg12345", - ) - ) - - # Set the last recently accessed time for this media - if last_accessed_ms is not None: - self.get_success( - hs.get_datastores().main.update_cached_last_access_time( - local_media=(), - remote_media=((self.remote_server_name, media_id),), - time_ms=last_accessed_ms, - ) - ) - - if is_quarantined: - # Mark this media as quarantined - self.get_success( - self.store.quarantine_media_by_id( - server_name=self.remote_server_name, - media_id=media_id, - quarantined_by="@theadmin:test", - ) - ) - - return MXCUri(self.remote_server_name, media_id) - - # Start with the local media store - self.local_recently_accessed_media = _create_media_and_set_attributes( - last_accessed_ms=self.THIRTY_DAYS_IN_MS, - ) - self.local_not_recently_accessed_media = _create_media_and_set_attributes( - last_accessed_ms=self.ONE_DAY_IN_MS, - ) - self.local_not_recently_accessed_quarantined_media = ( - _create_media_and_set_attributes( - last_accessed_ms=self.ONE_DAY_IN_MS, - is_quarantined=True, - ) - ) - self.local_not_recently_accessed_protected_media = ( - _create_media_and_set_attributes( - last_accessed_ms=self.ONE_DAY_IN_MS, - is_protected=True, - ) - ) - self.local_never_accessed_media = _create_media_and_set_attributes( - last_accessed_ms=None, - ) - - # And now the remote media store - self.remote_recently_accessed_media = _cache_remote_media_and_set_attributes( - media_id="a", - last_accessed_ms=self.THIRTY_DAYS_IN_MS, - ) - self.remote_not_recently_accessed_media = ( - _cache_remote_media_and_set_attributes( - media_id="b", - last_accessed_ms=self.ONE_DAY_IN_MS, - ) - ) - self.remote_not_recently_accessed_quarantined_media = ( - _cache_remote_media_and_set_attributes( - media_id="c", - last_accessed_ms=self.ONE_DAY_IN_MS, - is_quarantined=True, - ) - ) - # Remote media will always have a "last accessed" attribute, as it would 
not - # be fetched from the remote homeserver unless instigated by a user. - - @override_config( - { - "media_retention": { - # Enable retention for local media - "local_media_lifetime": "30d" - # Cached remote media should not be purged - } - } - ) - def test_local_media_retention(self) -> None: - """ - Tests that local media that have not been accessed recently is purged, while - cached remote media is unaffected. - """ - # Advance 31 days (in seconds) - self.reactor.advance(31 * 24 * 60 * 60) - - # Check that media has been correctly purged. - # Local media accessed <30 days ago should still exist. - # Remote media should be unaffected. - self._assert_if_mxc_uris_purged( - purged=[ - self.local_not_recently_accessed_media, - self.local_never_accessed_media, - ], - not_purged=[ - self.local_recently_accessed_media, - self.local_not_recently_accessed_quarantined_media, - self.local_not_recently_accessed_protected_media, - self.remote_recently_accessed_media, - self.remote_not_recently_accessed_media, - self.remote_not_recently_accessed_quarantined_media, - ], - ) - - @override_config( - { - "media_retention": { - # Enable retention for cached remote media - "remote_media_lifetime": "30d" - # Local media should not be purged - } - } - ) - def test_remote_media_cache_retention(self) -> None: - """ - Tests that entries from the remote media cache that have not been accessed - recently is purged, while local media is unaffected. - """ - # Advance 31 days (in seconds) - self.reactor.advance(31 * 24 * 60 * 60) - - # Check that media has been correctly purged. - # Local media should be unaffected. - # Remote media accessed <30 days ago should still exist. - self._assert_if_mxc_uris_purged( - purged=[ - self.remote_not_recently_accessed_media, - ], - not_purged=[ - self.remote_recently_accessed_media, - self.local_recently_accessed_media, - self.local_not_recently_accessed_media, - self.local_not_recently_accessed_quarantined_media, - self.local_not_recently_accessed_protected_media, - self.remote_not_recently_accessed_quarantined_media, - self.local_never_accessed_media, - ], - ) - - def _assert_if_mxc_uris_purged( - self, purged: Iterable[MXCUri], not_purged: Iterable[MXCUri] - ) -> None: - def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: - """Given an MXC URI, assert whether it has been purged or not.""" - if mxc_uri.server_name == self.hs.config.server.server_name: - found_media = bool( - self.get_success(self.store.get_local_media(mxc_uri.media_id)) - ) - else: - found_media = bool( - self.get_success( - self.store.get_cached_remote_media( - mxc_uri.server_name, mxc_uri.media_id - ) - ) - ) - - if expect_purged: - self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged") - else: - self.assertTrue( - found_media, - msg=f"{mxc_uri} unexpectedly purged", - ) - - # Assert that the given MXC URIs have either been correctly purged or not. - for mxc_uri in purged: - _assert_mxc_uri_purge_state(mxc_uri, expect_purged=True) - for mxc_uri in not_purged: - _assert_mxc_uri_purge_state(mxc_uri, expect_purged=False) -- cgit 1.5.1 From d6c3b7584fc46571e65226793304df35d7081534 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Nov 2023 14:03:42 -0500 Subject: Request & follow redirects for /media/v3/download (#16701) Implement MSC3860 to follow redirects for federated media downloads. Note that the Client-Server API doesn't support this (yet) since the media repository in Synapse doesn't have a way of supporting redirects. 
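
In outline, downloads now try the Matrix 1.7 v3 media endpoint first (asking the
remote to redirect where appropriate) and fall back to the legacy r0 endpoint only
when the remote does not recognise v3. A condensed sketch of that control flow,
illustrative only: it assumes the `is_unknown_endpoint` helper already used in
`federation_client.py`, and `transport` stands in for the `TransportLayerClient`
methods added below.

    from synapse.api.errors import HttpResponseException
    from synapse.federation.federation_client import is_unknown_endpoint

    async def download_with_fallback(
        transport, destination, media_id, out, max_size, max_timeout_ms
    ):
        try:
            # Matrix 1.7 endpoint; the transport asks for, and follows,
            # a single 307/308 redirect via allow_redirect=true.
            return await transport.download_media_v3(
                destination,
                media_id,
                output_stream=out,
                max_size=max_size,
                max_timeout_ms=max_timeout_ms,
            )
        except HttpResponseException as e:
            # Anything other than "unknown endpoint" is a genuine failure.
            if not is_unknown_endpoint(e):
                raise
            # The remote predates the v3 media API: retry on r0.
            return await transport.download_media_r0(
                destination,
                media_id,
                output_stream=out,
                max_size=max_size,
                max_timeout_ms=max_timeout_ms,
            )
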
--- changelog.d/16701.feature | 1 + synapse/federation/federation_client.py | 38 +++++++++++++++ synapse/federation/transport/client.py | 53 ++++++++++++++++++++ synapse/http/matrixfederationclient.py | 77 ++++++++++++++++++++++-------- synapse/media/media_repository.py | 17 ++----- tests/media/test_media_storage.py | 62 ++++++++++++++++++++++-- tests/replication/test_multi_media_repo.py | 2 +- 7 files changed, 212 insertions(+), 38 deletions(-) create mode 100644 changelog.d/16701.feature (limited to 'tests') diff --git a/changelog.d/16701.feature b/changelog.d/16701.feature new file mode 100644 index 0000000000..2a66fc932a --- /dev/null +++ b/changelog.d/16701.feature @@ -0,0 +1 @@ +Follow redirects when downloading media over federation (per [MSC3860](https://github.com/matrix-org/matrix-spec-proposals/pull/3860)). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 1a7fa175ec..0ba03b0d05 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -21,6 +21,7 @@ from typing import ( TYPE_CHECKING, AbstractSet, Awaitable, + BinaryIO, Callable, Collection, Container, @@ -1862,6 +1863,43 @@ class FederationClient(FederationBase): return filtered_statuses, filtered_failures + async def download_media( + self, + destination: str, + media_id: str, + output_stream: BinaryIO, + max_size: int, + max_timeout_ms: int, + ) -> Tuple[int, Dict[bytes, List[bytes]]]: + try: + return await self.transport_layer.download_media_v3( + destination, + media_id, + output_stream=output_stream, + max_size=max_size, + max_timeout_ms=max_timeout_ms, + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fallback to the r0 endpoint. Otherwise, consider it a legitimate error + # and raise. + if not is_unknown_endpoint(e): + raise + + logger.debug( + "Couldn't download media %s/%s with the v3 API, falling back to the r0 API", + destination, + media_id, + ) + + return await self.transport_layer.download_media_r0( + destination, + media_id, + output_stream=output_stream, + max_size=max_size, + max_timeout_ms=max_timeout_ms, + ) + @attr.s(frozen=True, slots=True, auto_attribs=True) class TimestampToEventResponse: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index fab4800717..5e36638b0a 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -18,6 +18,7 @@ import urllib from typing import ( TYPE_CHECKING, Any, + BinaryIO, Callable, Collection, Dict, @@ -804,6 +805,58 @@ class TransportLayerClient: destination=destination, path=path, data={"user_ids": user_ids} ) + async def download_media_r0( + self, + destination: str, + media_id: str, + output_stream: BinaryIO, + max_size: int, + max_timeout_ms: int, + ) -> Tuple[int, Dict[bytes, List[bytes]]]: + path = f"/_matrix/media/r0/download/{destination}/{media_id}" + + return await self.client.get_file( + destination, + path, + output_stream=output_stream, + max_size=max_size, + args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. 
+ "allow_remote": "false", + "timeout_ms": str(max_timeout_ms), + }, + ) + + async def download_media_v3( + self, + destination: str, + media_id: str, + output_stream: BinaryIO, + max_size: int, + max_timeout_ms: int, + ) -> Tuple[int, Dict[bytes, List[bytes]]]: + path = f"/_matrix/media/v3/download/{destination}/{media_id}" + + return await self.client.get_file( + destination, + path, + output_stream=output_stream, + max_size=max_size, + args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. + "allow_remote": "false", + "timeout_ms": str(max_timeout_ms), + # Matrix 1.7 allows for this to redirect to another URL, this should + # just be ignored for an old homeserver, so always provide it. + "allow_redirect": "true", + }, + follow_redirects=True, + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index d5013e8e97..cc1db763ae 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -153,12 +153,18 @@ class MatrixFederationRequest: """Query arguments. """ - txn_id: Optional[str] = None - """Unique ID for this request (for logging) + txn_id: str = attr.ib(init=False) + """Unique ID for this request (for logging), this is autogenerated. """ - uri: bytes = attr.ib(init=False) - """The URI of this request + uri: bytes = b"" + """The URI of this request, usually generated from the above information. + """ + + _generate_uri: bool = True + """True to automatically generate the uri field based on the above information. + + Set to False if manually configuring the URI. """ def __attrs_post_init__(self) -> None: @@ -168,22 +174,23 @@ class MatrixFederationRequest: object.__setattr__(self, "txn_id", txn_id) - destination_bytes = self.destination.encode("ascii") - path_bytes = self.path.encode("ascii") - query_bytes = encode_query_args(self.query) - - # The object is frozen so we can pre-compute this. - uri = urllib.parse.urlunparse( - ( - b"matrix-federation", - destination_bytes, - path_bytes, - None, - query_bytes, - b"", + if self._generate_uri: + destination_bytes = self.destination.encode("ascii") + path_bytes = self.path.encode("ascii") + query_bytes = encode_query_args(self.query) + + # The object is frozen so we can pre-compute this. + uri = urllib.parse.urlunparse( + ( + b"matrix-federation", + destination_bytes, + path_bytes, + None, + query_bytes, + b"", + ) ) - ) - object.__setattr__(self, "uri", uri) + object.__setattr__(self, "uri", uri) def get_json(self) -> Optional[JsonDict]: if self.json_callback: @@ -513,6 +520,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, backoff_on_404: bool = False, backoff_on_all_error_codes: bool = False, + follow_redirects: bool = False, ) -> IResponse: """ Sends a request to the given server. @@ -555,6 +563,9 @@ class MatrixFederationHttpClient: backoff_on_404: Back off if we get a 404 backoff_on_all_error_codes: Back off if we get any error response + follow_redirects: True to follow the Location header of 307/308 redirect + responses. This does not recurse. + Returns: Resolves with the HTTP response object on success. @@ -714,6 +725,26 @@ class MatrixFederationHttpClient: response.code, response_phrase, ) + elif ( + response.code in (307, 308) + and follow_redirects + and response.headers.hasHeader("Location") + ): + # The Location header *might* be relative so resolve it. 
+ location = response.headers.getRawHeaders(b"Location")[0] + new_uri = urllib.parse.urljoin(request.uri, location) + + return await self._send_request( + attr.evolve(request, uri=new_uri, generate_uri=False), + retry_on_dns_fail, + timeout, + long_retries, + ignore_backoff, + backoff_on_404, + backoff_on_all_error_codes, + # Do not continue following redirects. + follow_redirects=False, + ) else: logger.info( "{%s} [%s] Got response headers: %d %s", @@ -1383,6 +1414,7 @@ class MatrixFederationHttpClient: retry_on_dns_fail: bool = True, max_size: Optional[int] = None, ignore_backoff: bool = False, + follow_redirects: bool = False, ) -> Tuple[int, Dict[bytes, List[bytes]]]: """GETs a file from a given homeserver Args: @@ -1392,6 +1424,8 @@ class MatrixFederationHttpClient: args: Optional dictionary used to create the query string. ignore_backoff: true to ignore the historical backoff data and try the request anyway. + follow_redirects: True to follow the Location header of 307/308 redirect + responses. This does not recurse. Returns: Resolves with an (int,dict) tuple of @@ -1412,7 +1446,10 @@ class MatrixFederationHttpClient: ) response = await self._send_request( - request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff + request, + retry_on_dns_fail=retry_on_dns_fail, + ignore_backoff=ignore_backoff, + follow_redirects=follow_redirects, ) headers = dict(response.headers.getAllRawHeaders()) diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index bf976b9e7c..d62af22adb 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -77,7 +77,7 @@ class MediaRepository: def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.client = hs.get_federation_http_client() + self.client = hs.get_federation_client() self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastores().main @@ -644,22 +644,13 @@ class MediaRepository: file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): - request_path = "/".join( - ("/_matrix/media/r0/download", server_name, media_id) - ) try: - length, headers = await self.client.get_file( + length, headers = await self.client.download_media( server_name, - request_path, + media_id, output_stream=f, max_size=self.max_upload_size, - args={ - # tell the remote server to 404 if it doesn't - # recognise the server_name, to make sure we don't - # end up with a routing loop. 
- "allow_remote": "false", - "timeout_ms": str(max_timeout_ms), - }, + max_timeout_ms=max_timeout_ms, ) except RequestSendFailed as e: logger.warning( diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index f262304c3d..f981d1c0d8 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -27,10 +27,11 @@ from typing_extensions import Literal from twisted.internet import defer from twisted.internet.defer import Deferred +from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource -from synapse.api.errors import Codes +from synapse.api.errors import Codes, HttpResponseException from synapse.events import EventBase from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable @@ -247,6 +248,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): retry_on_dns_fail: bool = True, max_size: Optional[int] = None, ignore_backoff: bool = False, + follow_redirects: bool = False, ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": """A mock for MatrixFederationHttpClient.get_file.""" @@ -257,10 +259,15 @@ class MediaRepoTests(unittest.HomeserverTestCase): output_stream.write(data) return response + def write_err(f: Failure) -> Failure: + f.trap(HttpResponseException) + output_stream.write(f.value.response) + return f + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() self.fetches.append((d, destination, path, args)) # Note that this callback changes the value held by d. - d_after_callback = d.addCallback(write_to) + d_after_callback = d.addCallbacks(write_to, write_err) return make_deferred_yieldable(d_after_callback) # Mock out the homeserver's MatrixFederationHttpClient @@ -316,10 +323,11 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(len(self.fetches), 1) self.assertEqual(self.fetches[0][1], "example.com") self.assertEqual( - self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id + self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id ) self.assertEqual( - self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"} + self.fetches[0][3], + {"allow_remote": "false", "timeout_ms": "20000", "allow_redirect": "true"}, ) headers = { @@ -671,6 +679,52 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"cross-origin"], ) + def test_unknown_v3_endpoint(self) -> None: + """ + If the v3 endpoint fails, try the r0 one. + """ + channel = self.make_request( + "GET", + f"/_matrix/media/v3/download/{self.media_id}", + shorthand=False, + await_result=False, + ) + self.pump() + + # We've made one fetch, to example.com, using the media URL, and asking + # the other server not to do a remote fetch + self.assertEqual(len(self.fetches), 1) + self.assertEqual(self.fetches[0][1], "example.com") + self.assertEqual( + self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id + ) + + # The result which says the endpoint is unknown. + unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}' + self.fetches[0][0].errback( + HttpResponseException(404, "NOT FOUND", unknown_endpoint) + ) + + self.pump() + + # There should now be another request to the r0 URL. 
+ self.assertEqual(len(self.fetches), 2) + self.assertEqual(self.fetches[1][1], "example.com") + self.assertEqual( + self.fetches[1][2], f"/_matrix/media/r0/download/{self.media_id}" + ) + + headers = { + b"Content-Length": [b"%d" % (len(self.test_image.data))], + } + + self.fetches[1][0].callback( + (self.test_image.data, (len(self.test_image.data), headers)) + ) + + self.pump() + self.assertEqual(channel.code, 200) + class TestSpamCheckerLegacy: """A spam checker module that rejects all media that includes the bytes diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 1e9994cc0b..9a7b675f54 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -133,7 +133,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - f"/_matrix/media/r0/download/{target}/{media_id}".encode(), + f"/_matrix/media/v3/download/{target}/{media_id}".encode(), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] -- cgit 1.5.1
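
A note on the redirect handling added in `matrixfederationclient.py` above: the
`Location` header of a 307/308 response may be relative, so `_send_request`
resolves it against the original request URI with `urllib.parse.urljoin` before
retrying. A standalone illustration of that resolution step (the URLs here are
made up; the real request URI uses the internal matrix-federation scheme):

    from urllib.parse import urljoin

    base = "https://media.example.com/_matrix/media/v3/download/example.com/abc123"

    # A relative Location is resolved against the original URI...
    print(urljoin(base, "/_matrix/media/v3/real-download/abc123"))
    # https://media.example.com/_matrix/media/v3/real-download/abc123

    # ...while an absolute Location is used as-is.
    print(urljoin(base, "https://cdn.example.com/abc123"))
    # https://cdn.example.com/abc123

Because the nested `_send_request` call passes `follow_redirects=False`, at most
one redirect is followed per request, which prevents redirect loops between
misconfigured servers.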