summary refs log tree commit diff
path: root/synapse/rest/media
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/media')
-rw-r--r--synapse/rest/media/v1/media_repository.py69
-rw-r--r--synapse/rest/media/v1/media_storage.py8
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py55
-rw-r--r--synapse/rest/media/v1/storage_provider.py31
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py2
5 files changed, 114 insertions, 51 deletions
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 578d7f07c5..bb79599379 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -27,15 +27,14 @@ from .identicon_resource import IdenticonResource
 from .preview_url_resource import PreviewUrlResource
 from .filepath import MediaFilePaths
 from .thumbnailer import Thumbnailer
-from .storage_provider import (
-    StorageProviderWrapper, FileStorageProviderBackend,
-)
+from .storage_provider import StorageProviderWrapper
 from .media_storage import MediaStorage
 
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError, HttpResponseException, \
-    NotFoundError
+from synapse.api.errors import (
+    SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
+)
 
 from synapse.util.async import Linearizer
 from synapse.util.stringutils import is_ascii
@@ -53,7 +52,7 @@ import urlparse
 logger = logging.getLogger(__name__)
 
 
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository(object):
@@ -75,22 +74,21 @@ class MediaRepository(object):
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
         self.recently_accessed_remotes = set()
+        self.recently_accessed_locals = set()
+
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
         # List of StorageProviders where we should search for media and
         # potentially upload to.
         storage_providers = []
 
-        # TODO: Move this into config and allow other storage providers to be
-        # defined.
-        if hs.config.backup_media_store_path:
-            backend = FileStorageProviderBackend(
-                self.primary_base_path, hs.config.backup_media_store_path,
-            )
+        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+            backend = clz(hs, provider_config)
             provider = StorageProviderWrapper(
                 backend,
-                store=True,
-                store_synchronous=hs.config.synchronous_backup_media_store,
-                store_remote=True,
+                store_local=wrapper_config.store_local,
+                store_remote=wrapper_config.store_remote,
+                store_synchronous=wrapper_config.store_synchronous,
             )
             storage_providers.append(provider)
 
@@ -99,19 +97,34 @@ class MediaRepository(object):
         )
 
         self.clock.looping_call(
-            self._update_recently_accessed_remotes,
-            UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+            self._update_recently_accessed,
+            UPDATE_RECENTLY_ACCESSED_TS,
         )
 
     @defer.inlineCallbacks
-    def _update_recently_accessed_remotes(self):
-        media = self.recently_accessed_remotes
+    def _update_recently_accessed(self):
+        remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
 
+        local_media = self.recently_accessed_locals
+        self.recently_accessed_locals = set()
+
         yield self.store.update_cached_last_access_time(
-            media, self.clock.time_msec()
+            local_media, remote_media, self.clock.time_msec()
         )
 
+    def mark_recently_accessed(self, server_name, media_id):
+        """Mark the given media as recently accessed.
+
+        Args:
+            server_name (str|None): Origin server of media, or None if local
+            media_id (str): The media ID of the content
+        """
+        if server_name:
+            self.recently_accessed_remotes.add((server_name, media_id))
+        else:
+            self.recently_accessed_locals.add(media_id)
+
     @defer.inlineCallbacks
     def create_content(self, media_type, upload_name, content, content_length,
                        auth_user):
@@ -173,6 +186,8 @@ class MediaRepository(object):
             respond_404(request)
             return
 
+        self.mark_recently_accessed(None, media_id)
+
         media_type = media_info["media_type"]
         media_length = media_info["media_length"]
         upload_name = name if name else media_info["upload_name"]
@@ -204,7 +219,13 @@ class MediaRepository(object):
             Deferred: Resolves once a response has successfully been written
                 to request
         """
-        self.recently_accessed_remotes.add((server_name, media_id))
+        if (
+            self.federation_domain_whitelist is not None and
+            server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
+        self.mark_recently_accessed(server_name, media_id)
 
         # We linearize here to ensure that we don't try and download remote
         # media multiple times concurrently
@@ -238,6 +259,12 @@ class MediaRepository(object):
         Returns:
             Deferred[dict]: The media_info of the file
         """
+        if (
+            self.federation_domain_whitelist is not None and
+            server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
         # We linearize here to ensure that we don't try and download remote
         # media multiple times concurrently
         key = (server_name, media_id)
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 5c3e4e5a65..e8e8b3986d 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -197,6 +197,14 @@ class MediaStorage(object):
             str
         """
         if file_info.url_cache:
+            if file_info.thumbnail:
+                return self.filepaths.url_cache_thumbnail_rel(
+                    media_id=file_info.file_id,
+                    width=file_info.thumbnail_width,
+                    height=file_info.thumbnail_height,
+                    content_type=file_info.thumbnail_type,
+                    method=file_info.thumbnail_method,
+                )
             return self.filepaths.url_cache_filepath_rel(file_info.file_id)
 
         if file_info.server_name:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 981f01e417..31fe7aa75c 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,6 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import ujson as json
+import urlparse
 
 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
@@ -33,18 +46,6 @@ from synapse.http.server import (
 from synapse.util.async import ObservableDeferred
 from synapse.util.stringutils import is_ascii
 
-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
-import datetime
-import errno
-import shutil
-
-import logging
 logger = logging.getLogger(__name__)
 
 
@@ -286,17 +287,28 @@ class PreviewUrlResource(Resource):
             url_cache=True,
         )
 
-        try:
-            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            try:
                 logger.debug("Trying to get url '%s'" % url)
                 length, headers, uri, code = yield self.client.get_file(
                     url, output_stream=f, max_size=self.max_spider_size,
                 )
+            except Exception as e:
                 # FIXME: pass through 404s and other error messages nicely
+                logger.warn("Error downloading %s: %r", url, e)
+                raise SynapseError(
+                    500, "Failed to download content: %s" % (
+                        traceback.format_exception_only(sys.exc_type, e),
+                    ),
+                    Codes.UNKNOWN,
+                )
+            yield finish()
 
-                yield finish()
-
-            media_type = headers["Content-Type"][0]
+        try:
+            if "Content-Type" in headers:
+                media_type = headers["Content-Type"][0]
+            else:
+                media_type = "application/octet-stream"
             time_now_ms = self.clock.time_msec()
 
             content_disposition = headers.get("Content-Disposition", None)
@@ -336,10 +348,11 @@ class PreviewUrlResource(Resource):
             )
 
         except Exception as e:
-            raise SynapseError(
-                500, ("Failed to download content: %s" % e),
-                Codes.UNKNOWN
-            )
+            logger.error("Error handling downloaded %s: %r", url, e)
+            # TODO: we really ought to delete the downloaded file in this
+            # case, since we won't have recorded it in the db, and will
+            # therefore not expire it.
+            raise
 
         defer.returnValue({
             "media_type": media_type,
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 2ad602e101..c188192f2b 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -17,6 +17,7 @@ from twisted.internet import defer, threads
 
 from .media_storage import FileResponder
 
+from synapse.config._base import Config
 from synapse.util.logcontext import preserve_fn
 
 import logging
@@ -64,19 +65,19 @@ class StorageProviderWrapper(StorageProvider):
 
     Args:
         backend (StorageProvider)
-        store (bool): Whether to store new files or not.
+        store_local (bool): Whether to store new local files or not.
         store_synchronous (bool): Whether to wait for file to be successfully
             uploaded, or todo the upload in the backgroud.
         store_remote (bool): Whether remote media should be uploaded
     """
-    def __init__(self, backend, store, store_synchronous, store_remote):
+    def __init__(self, backend, store_local, store_synchronous, store_remote):
         self.backend = backend
-        self.store = store
+        self.store_local = store_local
         self.store_synchronous = store_synchronous
         self.store_remote = store_remote
 
     def store_file(self, path, file_info):
-        if not self.store:
+        if not file_info.server_name and not self.store_local:
             return defer.succeed(None)
 
         if file_info.server_name and not self.store_remote:
@@ -97,13 +98,13 @@ class FileStorageProviderBackend(StorageProvider):
     """A storage provider that stores files in a directory on a filesystem.
 
     Args:
-        cache_directory (str): Base path of the local media repository
-        base_directory (str): Base path to store new files
+        hs (HomeServer)
+        config: The config returned by `parse_config`.
     """
 
-    def __init__(self, cache_directory, base_directory):
-        self.cache_directory = cache_directory
-        self.base_directory = base_directory
+    def __init__(self, hs, config):
+        self.cache_directory = hs.config.media_store_path
+        self.base_directory = config
 
     def store_file(self, path, file_info):
         """See StorageProvider.store_file"""
@@ -125,3 +126,15 @@ class FileStorageProviderBackend(StorageProvider):
         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
             return FileResponder(open(backup_fname, "rb"))
+
+    @staticmethod
+    def parse_config(config):
+        """Called on startup to parse config supplied. This should parse
+        the config and raise if there is a problem.
+
+        The returned value is passed into the constructor.
+
+        In this case we only care about a single param, the directory, so let's
+        just pull that out.
+        """
+        return Config.ensure_directory(config["directory"])
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 49e4af514c..58ada49711 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -67,6 +67,7 @@ class ThumbnailResource(Resource):
                 yield self._respond_local_thumbnail(
                     request, media_id, width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(None, media_id)
         else:
             if self.dynamic_thumbnails:
                 yield self._select_or_generate_remote_thumbnail(
@@ -78,6 +79,7 @@ class ThumbnailResource(Resource):
                     request, server_name, media_id,
                     width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(server_name, media_id)
 
     @defer.inlineCallbacks
     def _respond_local_thumbnail(self, request, media_id, width, height,