diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 87c929eb20..8bc92305fe 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -430,6 +430,7 @@ class MediaRepository:
media_id: str,
name: Optional[str],
max_timeout_ms: int,
+ allow_authenticated: bool = True,
federation: bool = False,
) -> None:
"""Responds to requests for local media, if exists, or returns 404.
@@ -442,6 +443,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ allow_authenticated: whether media marked as authenticated may be served to this request
federation: whether the local media being fetched is for a federation request
Returns:
@@ -451,6 +453,10 @@ class MediaRepository:
if not media_info:
return
+ if self.hs.config.media.enable_authenticated_media and not allow_authenticated:
+ if media_info.authenticated:
+ raise NotFoundError()
+
self.mark_recently_accessed(None, media_id)
media_type = media_info.media_type
@@ -481,6 +487,7 @@ class MediaRepository:
max_timeout_ms: int,
ip_address: str,
use_federation_endpoint: bool,
+ allow_authenticated: bool = True,
) -> None:
"""Respond to requests for remote media.
@@ -495,6 +502,8 @@ class MediaRepository:
ip_address: the IP address of the requester
use_federation_endpoint: whether to request the remote media over the new
federation `/download` endpoint
+ allow_authenticated: whether media marked as authenticated may be served to this
+ request
Returns:
Resolves once a response has successfully been written to request
@@ -526,6 +535,7 @@ class MediaRepository:
self.download_ratelimiter,
ip_address,
use_federation_endpoint,
+ allow_authenticated,
)
# We deliberately stream the file outside the lock
@@ -548,6 +558,7 @@ class MediaRepository:
max_timeout_ms: int,
ip_address: str,
use_federation: bool,
+ allow_authenticated: bool,
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -560,6 +571,8 @@ class MediaRepository:
ip_address: IP address of the requester
use_federation: if a download is necessary, whether to request the remote file
over the federation `/download` endpoint
+ allow_authenticated: whether media marked as authenticated may be served to this
+ request
Returns:
The media info of the file
@@ -581,6 +594,7 @@ class MediaRepository:
self.download_ratelimiter,
ip_address,
use_federation,
+ allow_authenticated,
)
# Ensure we actually use the responder so that it releases resources
@@ -598,6 +612,7 @@ class MediaRepository:
download_ratelimiter: Ratelimiter,
ip_address: str,
use_federation_endpoint: bool,
+ allow_authenticated: bool,
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -619,6 +634,11 @@ class MediaRepository:
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
+ if self.hs.config.media.enable_authenticated_media and not allow_authenticated:
+ # if it isn't cached then don't fetch it or if it's authenticated then don't serve it
+ if not media_info or media_info.authenticated:
+ raise NotFoundError()
+
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
# one.
@@ -792,6 +812,11 @@ class MediaRepository:
logger.info("Stored remote media in file %r", fname)
+ if self.hs.config.media.enable_authenticated_media:
+ authenticated = True
+ else:
+ authenticated = False
+
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
@@ -802,6 +827,7 @@ class MediaRepository:
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
+ authenticated=authenticated,
)
async def _federation_download_remote_file(
@@ -915,6 +941,11 @@ class MediaRepository:
logger.debug("Stored remote media in file %r", fname)
+ if self.hs.config.media.enable_authenticated_media:
+ authenticated = True
+ else:
+ authenticated = False
+
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
@@ -925,6 +956,7 @@ class MediaRepository:
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
+ authenticated=authenticated,
)
def _get_thumbnail_requirements(
@@ -1030,7 +1062,12 @@ class MediaRepository:
t_len = os.path.getsize(output_path)
await self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
+ media_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
)
return output_path
|