summary refs log tree commit diff
path: root/synapse/rest/media/v1
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2020-07-30 19:00:29 +0100
committerBrendan Abolivier <babolivier@matrix.org>2020-07-30 19:00:29 +0100
commit69158e554f30ac8b6b646a62fa496a2c0005dea6 (patch)
tree42fdb177abede9c0128906d4e6661cde0ee9cd6c /synapse/rest/media/v1
parentChangelog (diff)
parentUpdate workers docs (#7990) (diff)
downloadsynapse-69158e554f30ac8b6b646a62fa496a2c0005dea6.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/new_push_rules
Diffstat (limited to 'synapse/rest/media/v1')
-rw-r--r--synapse/rest/media/v1/_base.py23
-rw-r--r--synapse/rest/media/v1/media_repository.py105
-rw-r--r--synapse/rest/media/v1/media_storage.py96
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py275
-rw-r--r--synapse/rest/media/v1/storage_provider.py62
5 files changed, 384 insertions, 177 deletions
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 595849f9d5..20ddb9550b 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,8 +17,9 @@
 import logging
 import os
 import urllib
+from typing import Awaitable
 
-from twisted.internet import defer
+from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
 from synapse.api.errors import Codes, SynapseError, cs_error
@@ -77,8 +78,9 @@ def respond_404(request):
     )
 
 
-@defer.inlineCallbacks
-def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
+async def respond_with_file(
+    request, media_type, file_path, file_size=None, upload_name=None
+):
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
@@ -89,7 +91,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
         add_file_headers(request, media_type, file_size, upload_name)
 
         with open(file_path, "rb") as f:
-            yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+            await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
 
         finish_request(request)
     else:
@@ -198,8 +200,9 @@ def _can_encode_filename_as_token(x):
     return True
 
 
-@defer.inlineCallbacks
-def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+async def respond_with_responder(
+    request, responder, media_type, file_size, upload_name=None
+):
     """Responds to the request with given responder. If responder is None then
     returns 404.
 
@@ -218,7 +221,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
     add_file_headers(request, media_type, file_size, upload_name)
     try:
         with responder:
-            yield responder.write_to_consumer(request)
+            await responder.write_to_consumer(request)
     except Exception as e:
         # The majority of the time this will be due to the client having gone
         # away. Unfortunately, Twisted simply throws a generic exception at us
@@ -239,14 +242,14 @@ class Responder(object):
     held can be cleaned up.
     """
 
-    def write_to_consumer(self, consumer):
+    def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
         """Stream response into consumer
 
         Args:
-            consumer (IConsumer)
+            consumer: The consumer to stream into.
 
         Returns:
-            Deferred: Resolves once the response has finished being written
+            Resolves once the response has finished being written
         """
         pass
 
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 45628c07b4..6fb4039e98 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,10 +18,11 @@ import errno
 import logging
 import os
 import shutil
-from typing import Dict, Tuple
+from typing import IO, Dict, Optional, Tuple
 
 import twisted.internet.error
 import twisted.web.http
+from twisted.web.http import Request
 from twisted.web.resource import Resource
 
 from synapse.api.errors import (
@@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
 
 from ._base import (
     FileInfo,
+    Responder,
     get_filename_from_headers,
     respond_404,
     respond_with_responder,
@@ -135,19 +137,24 @@ class MediaRepository(object):
             self.recently_accessed_locals.add(media_id)
 
     async def create_content(
-        self, media_type, upload_name, content, content_length, auth_user
-    ):
+        self,
+        media_type: str,
+        upload_name: str,
+        content: IO,
+        content_length: int,
+        auth_user: str,
+    ) -> str:
         """Store uploaded content for a local user and return the mxc URL
 
         Args:
-            media_type(str): The content type of the file
-            upload_name(str): The name of the file
+            media_type: The content type of the file
+            upload_name: The name of the file
             content: A file like object that is the content to store
-            content_length(int): The length of the content
-            auth_user(str): The user_id of the uploader
+            content_length: The length of the content
+            auth_user: The user_id of the uploader
 
         Returns:
-            Deferred[str]: The mxc url of the stored content
+            The mxc url of the stored content
         """
         media_id = random_string(24)
 
@@ -170,19 +177,20 @@ class MediaRepository(object):
 
         return "mxc://%s/%s" % (self.server_name, media_id)
 
-    async def get_local_media(self, request, media_id, name):
+    async def get_local_media(
+        self, request: Request, media_id: str, name: Optional[str]
+    ) -> None:
         """Responds to reqests for local media, if exists, or returns 404.
 
         Args:
-            request(twisted.web.http.Request)
-            media_id (str): The media ID of the content. (This is the same as
+            request: The incoming request.
+            media_id: The media ID of the content. (This is the same as
                 the file_id for local content.)
-            name (str|None): Optional name that, if specified, will be used as
+            name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
 
         Returns:
-            Deferred: Resolves once a response has successfully been written
-                to request
+            Resolves once a response has successfully been written to request
         """
         media_info = await self.store.get_local_media(media_id)
         if not media_info or media_info["quarantined_by"]:
@@ -203,20 +211,20 @@ class MediaRepository(object):
             request, responder, media_type, media_length, upload_name
         )
 
-    async def get_remote_media(self, request, server_name, media_id, name):
+    async def get_remote_media(
+        self, request: Request, server_name: str, media_id: str, name: Optional[str]
+    ) -> None:
         """Respond to requests for remote media.
 
         Args:
-            request(twisted.web.http.Request)
-            server_name (str): Remote server_name where the media originated.
-            media_id (str): The media ID of the content (as defined by the
-                remote server).
-            name (str|None): Optional name that, if specified, will be used as
+            request: The incoming request.
+            server_name: Remote server_name where the media originated.
+            media_id: The media ID of the content (as defined by the remote server).
+            name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
 
         Returns:
-            Deferred: Resolves once a response has successfully been written
-                to request
+            Resolves once a response has successfully been written to request
         """
         if (
             self.federation_domain_whitelist is not None
@@ -245,17 +253,16 @@ class MediaRepository(object):
         else:
             respond_404(request)
 
-    async def get_remote_media_info(self, server_name, media_id):
+    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
         """Gets the media info associated with the remote file, downloading
         if necessary.
 
         Args:
-            server_name (str): Remote server_name where the media originated.
-            media_id (str): The media ID of the content (as defined by the
-                remote server).
+            server_name: Remote server_name where the media originated.
+            media_id: The media ID of the content (as defined by the remote server).
 
         Returns:
-            Deferred[dict]: The media_info of the file
+            The media info of the file
         """
         if (
             self.federation_domain_whitelist is not None
@@ -278,7 +285,9 @@ class MediaRepository(object):
 
         return media_info
 
-    async def _get_remote_media_impl(self, server_name, media_id):
+    async def _get_remote_media_impl(
+        self, server_name: str, media_id: str
+    ) -> Tuple[Optional[Responder], dict]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
 
@@ -288,7 +297,7 @@ class MediaRepository(object):
                 remote server).
 
         Returns:
-            Deferred[(Responder, media_info)]
+            A tuple of responder and the media info of the file.
         """
         media_info = await self.store.get_cached_remote_media(server_name, media_id)
 
@@ -319,19 +328,21 @@ class MediaRepository(object):
         responder = await self.media_storage.fetch_media(file_info)
         return responder, media_info
 
-    async def _download_remote_file(self, server_name, media_id, file_id):
+    async def _download_remote_file(
+        self, server_name: str, media_id: str, file_id: str
+    ) -> dict:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
         Args:
-            server_name (str): Originating server
-            media_id (str): The media ID of the content (as defined by the
+            server_name: Originating server
+            media_id: The media ID of the content (as defined by the
                 remote server). This is different than the file_id, which is
                 locally generated.
-            file_id (str): Local file ID
+            file_id: Local file ID
 
         Returns:
-            Deferred[MediaInfo]
+            The media info of the file.
         """
 
         file_info = FileInfo(server_name=server_name, file_id=file_id)
@@ -549,25 +560,31 @@ class MediaRepository(object):
             return output_path
 
     async def _generate_thumbnails(
-        self, server_name, media_id, file_id, media_type, url_cache=False
-    ):
+        self,
+        server_name: Optional[str],
+        media_id: str,
+        file_id: str,
+        media_type: str,
+        url_cache: bool = False,
+    ) -> Optional[dict]:
         """Generate and store thumbnails for an image.
 
         Args:
-            server_name (str|None): The server name if remote media, else None if local
-            media_id (str): The media ID of the content. (This is the same as
+            server_name: The server name if remote media, else None if local
+            media_id: The media ID of the content. (This is the same as
                 the file_id for local content)
-            file_id (str): Local file ID
-            media_type (str): The content type of the file
-            url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+            file_id: Local file ID
+            media_type: The content type of the file
+            url_cache: If we are thumbnailing images downloaded for the URL cache,
                 used exclusively by the url previewer
 
         Returns:
-            Deferred[dict]: Dict with "width" and "height" keys of original image
+            Dict with "width" and "height" keys of original image or None if the
+            media cannot be thumbnailed.
         """
         requirements = self._get_thumbnail_requirements(media_type)
         if not requirements:
-            return
+            return None
 
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(server_name, file_id, url_cache=url_cache)
@@ -584,7 +601,7 @@ class MediaRepository(object):
                 m_height,
                 self.max_image_pixels,
             )
-            return
+            return None
 
         if thumbnailer.transpose_method is not None:
             m_width, m_height = await defer_to_thread(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 79cb0dddbe..858b6d3005 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -12,19 +12,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import contextlib
+import inspect
 import logging
 import os
 import shutil
+from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
 
-from twisted.internet import defer
 from twisted.protocols.basic import FileSender
 
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
 from synapse.util.file_consumer import BackgroundFileConsumer
 
-from ._base import Responder
+from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+    from .storage_provider import StorageProvider
 
 logger = logging.getLogger(__name__)
 
@@ -33,49 +39,53 @@ class MediaStorage(object):
     """Responsible for storing/fetching files from local sources.
 
     Args:
-        hs (synapse.server.Homeserver)
-        local_media_directory (str): Base path where we store media on disk
-        filepaths (MediaFilePaths)
-        storage_providers ([StorageProvider]): List of StorageProvider that are
-            used to fetch and store files.
+        hs
+        local_media_directory: Base path where we store media on disk
+        filepaths
+        storage_providers: List of StorageProvider that are used to fetch and store files.
     """
 
-    def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        local_media_directory: str,
+        filepaths: MediaFilePaths,
+        storage_providers: Sequence["StorageProvider"],
+    ):
         self.hs = hs
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
 
-    @defer.inlineCallbacks
-    def store_file(self, source, file_info):
+    async def store_file(self, source: IO, file_info: FileInfo) -> str:
         """Write `source` to the on disk media store, and also any other
         configured storage providers
 
         Args:
             source: A file like object that should be written
-            file_info (FileInfo): Info about the file to store
+            file_info: Info about the file to store
 
         Returns:
-            Deferred[str]: the file path written to in the primary media store
+            the file path written to in the primary media store
         """
 
         with self.store_into_file(file_info) as (f, fname, finish_cb):
             # Write to the main repository
-            yield defer_to_thread(
+            await defer_to_thread(
                 self.hs.get_reactor(), _write_file_synchronously, source, f
             )
-            yield finish_cb()
+            await finish_cb()
 
         return fname
 
     @contextlib.contextmanager
-    def store_into_file(self, file_info):
+    def store_into_file(self, file_info: FileInfo):
         """Context manager used to get a file like object to write into, as
         described by file_info.
 
         Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
         like object that can be written to, fname is the absolute path of file
-        on disk, and finish_cb is a function that returns a Deferred.
+        on disk, and finish_cb is a function that returns an awaitable.
 
         fname can be used to read the contents from after upload, e.g. to
         generate thumbnails.
@@ -85,13 +95,13 @@ class MediaStorage(object):
         error.
 
         Args:
-            file_info (FileInfo): Info about the file to store
+            file_info: Info about the file to store
 
         Example:
 
             with media_storage.store_into_file(info) as (f, fname, finish_cb):
                 # .. write into f ...
-                yield finish_cb()
+                await finish_cb()
         """
 
         path = self._file_info_to_path(file_info)
@@ -103,10 +113,13 @@ class MediaStorage(object):
 
         finished_called = [False]
 
-        @defer.inlineCallbacks
-        def finish():
+        async def finish():
             for provider in self.storage_providers:
-                yield provider.store_file(path, file_info)
+                # store_file is supposed to return an Awaitable, but guard
+                # against improper implementations.
+                result = provider.store_file(path, file_info)
+                if inspect.isawaitable(result):
+                    await result
 
             finished_called[0] = True
 
@@ -123,17 +136,15 @@ class MediaStorage(object):
         if not finished_called:
             raise Exception("Finished callback not called")
 
-    @defer.inlineCallbacks
-    def fetch_media(self, file_info):
+    async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
         and configured storage providers.
 
         Args:
-            file_info (FileInfo)
+            file_info
 
         Returns:
-            Deferred[Responder|None]: Returns a Responder if the file was found,
-                otherwise None.
+            Returns a Responder if the file was found, otherwise None.
         """
 
         path = self._file_info_to_path(file_info)
@@ -142,23 +153,26 @@ class MediaStorage(object):
             return FileResponder(open(local_path, "rb"))
 
         for provider in self.storage_providers:
-            res = yield provider.fetch(path, file_info)
+            res = provider.fetch(path, file_info)  # type: Any
+            # Fetch is supposed to return an Awaitable[Responder], but guard
+            # against improper implementations.
+            if inspect.isawaitable(res):
+                res = await res
             if res:
                 logger.debug("Streaming %s from %s", path, provider)
                 return res
 
         return None
 
-    @defer.inlineCallbacks
-    def ensure_media_is_in_local_cache(self, file_info):
+    async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
         """Ensures that the given file is in the local cache. Attempts to
         download it from storage providers if it isn't.
 
         Args:
-            file_info (FileInfo)
+            file_info
 
         Returns:
-            Deferred[str]: Full path to local file
+            Full path to local file
         """
         path = self._file_info_to_path(file_info)
         local_path = os.path.join(self.local_media_directory, path)
@@ -170,29 +184,27 @@ class MediaStorage(object):
             os.makedirs(dirname)
 
         for provider in self.storage_providers:
-            res = yield provider.fetch(path, file_info)
+            res = provider.fetch(path, file_info)  # type: Any
+            # Fetch is supposed to return an Awaitable[Responder], but guard
+            # against improper implementations.
+            if inspect.isawaitable(res):
+                res = await res
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
                         open(local_path, "wb"), self.hs.get_reactor()
                     )
-                    yield res.write_to_consumer(consumer)
-                    yield consumer.wait()
+                    await res.write_to_consumer(consumer)
+                    await consumer.wait()
                 return local_path
 
         raise Exception("file could not be found")
 
-    def _file_info_to_path(self, file_info):
+    def _file_info_to_path(self, file_info: FileInfo) -> str:
         """Converts file_info into a relative path.
 
         The path is suitable for storing files under a directory, e.g. used to
         store files on local FS under the base media repository directory.
-
-        Args:
-            file_info (FileInfo)
-
-        Returns:
-            str
         """
         if file_info.url_cache:
             if file_info.thumbnail:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e52c86c798..e12f65a206 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -26,6 +26,7 @@ import traceback
 from typing import Dict, Optional
 from urllib import parse as urlparse
 
+import attr
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
 OG_TAG_NAME_MAXLEN = 50
 OG_TAG_VALUE_MAXLEN = 1000
 
+# One hour, expressed in milliseconds (it is passed as an `expiry_ms` value).
+ONE_HOUR = 60 * 60 * 1000
+
+# A map of oEmbed API endpoints to the URL globs that each endpoint serves.
+_oembed_globs = {
+    # Twitter.
+    "https://publish.twitter.com/oembed": [
+        "https://twitter.com/*/status/*",
+        "https://*.twitter.com/*/status/*",
+        "https://twitter.com/*/moments/*",
+        "https://*.twitter.com/*/moments/*",
+        # Include the HTTP versions too.
+        "http://twitter.com/*/status/*",
+        "http://*.twitter.com/*/status/*",
+        "http://twitter.com/*/moments/*",
+        "http://*.twitter.com/*/moments/*",
+    ],
+}
+# Convert the globs to regular expressions. The resulting dict maps each
+# compiled pattern to the API endpoint that should be queried for matching URLs.
+_oembed_patterns = {}
+for endpoint, globs in _oembed_globs.items():
+    for glob in globs:
+        # Convert the glob into a sane regular expression to match against. The
+        # rules followed will be slightly different for the domain portion vs.
+        # the rest.
+        #
+        # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
+        # 2. The domain can have globs, but we limit it to characters that can
+        #    reasonably be a domain part.
+        #    TODO: This does not attempt to handle Unicode domain names.
+        # 3. Other parts allow a glob to be any one, or more, characters.
+        results = urlparse.urlparse(glob)
+
+        # Ensure the scheme does not have wildcards (and is a sane scheme).
+        if results.scheme not in {"http", "https"}:
+            raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
+
+        # Each component is regex-escaped first, which turns a glob "*" into
+        # "\*"; that escaped star is then swapped for the wildcard sub-pattern
+        # appropriate to the component (restricted for the domain, ".+" for
+        # everything after it).
+        pattern = urlparse.urlunparse(
+            [
+                results.scheme,
+                re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+            ]
+            + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+        )
+        _oembed_patterns[re.compile(pattern)] = endpoint
+
+
+@attr.s
+class OEmbedResult:
+    """The parts of an oEmbed response that are used to build a URL preview."""
+
+    # Either HTML content or URL must be provided.
+    html = attr.ib(type=Optional[str])
+    url = attr.ib(type=Optional[str])
+    title = attr.ib(type=Optional[str])
+    # Number of seconds to cache the content. This is None when the oEmbed
+    # response did not carry a cache age, hence the Optional type.
+    # NOTE(review): _download_url uses this value interchangeably with
+    # ONE_HOUR, which is in milliseconds -- confirm the intended unit.
+    cache_age = attr.ib(type=Optional[int])
+
+
+class OEmbedError(Exception):
+    """An error occurred processing the oEmbed object.
+
+    Raised by PreviewUrlResource._get_oembed_content when fetching or parsing
+    the oEmbed response fails.
+    """
+
 
 class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
@@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             cache_name="url_previews",
             clock=self.clock,
             # don't spider URLs more often than once an hour
-            expiry_ms=60 * 60 * 1000,
+            expiry_ms=ONE_HOUR,
         )
 
         if self._worker_run_media_background_jobs:
@@ -171,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
         og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
         respond_with_json_bytes(request, 200, og, send_cors=True)
 
-    async def _do_preview(self, url, user, ts):
+    async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
         """Check the db, and download the URL and build a preview
 
         Args:
-            url (str):
-            user (str):
-            ts (int):
+            url: The URL to preview.
+            user: The user requesting the preview.
+            ts: The timestamp requested for the preview.
 
         Returns:
-            Deferred[bytes]: json-encoded og data
+            json-encoded og data
         """
         # check the URL cache in the DB (which will also provide us with
         # historical previews, if we have any)
@@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         return jsonog.encode("utf8")
 
+    def _get_oembed_url(self, url: str) -> Optional[str]:
+        """
+        Look up the oEmbed API endpoint, if any, that serves the given URL.
+
+        Args:
+            url: The URL to check.
+
+        Returns:
+            The endpoint to query instead of the URL, or None if no oEmbed
+            pattern matches and the original URL should be used.
+        """
+        matching_endpoints = (
+            api_endpoint
+            for compiled_glob, api_endpoint in _oembed_patterns.items()
+            if compiled_glob.fullmatch(url)
+        )
+        # The first pattern that matches the whole URL wins; fall back to None.
+        return next(matching_endpoints, None)
+
+    async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+        """
+        Request content from an oEmbed endpoint.
+
+        Args:
+            endpoint: The oEmbed API endpoint.
+            url: The URL to pass to the API.
+
+        Returns:
+            An object representing the metadata returned. At least one of its
+            `html` / `url` fields is non-empty.
+
+        Raises:
+            OEmbedError if fetching or parsing of the oEmbed information fails,
+            or if the response carries neither HTML nor a URL to download.
+        """
+        try:
+            logger.debug("Trying to get oEmbed content for url '%s'", url)
+            result = await self.client.get_json(
+                endpoint,
+                # TODO Specify max height / width.
+                # Note that only the JSON format is supported.
+                args={"url": url},
+            )
+
+            # Ensure there's a version of 1.0.
+            if result.get("version") != "1.0":
+                raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+            oembed_type = result.get("type")
+
+            # Ensure the cache age is None or an int: any missing or falsy
+            # value (0, "", ...) is normalised to None.
+            cache_age = result.get("cache_age")
+            cache_age = int(cache_age) if cache_age else None
+
+            oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+            if oembed_type == "rich":
+                # HTML content.
+                oembed_result.html = result.get("html")
+            elif oembed_type == "photo":
+                oembed_result.url = result.get("url")
+            elif "thumbnail_url" in result:
+                # TODO Handle link and video types.
+                oembed_result.url = result.get("thumbnail_url")
+            else:
+                raise OEmbedError("Incompatible oEmbed information.")
+
+            # A response of a recognised type can still omit (or leave empty)
+            # the field we need. Never hand back a result with neither HTML
+            # nor a URL, as callers rely on one of them being present.
+            if not (oembed_result.html or oembed_result.url):
+                raise OEmbedError("oEmbed response has no usable content.")
+
+            return oembed_result
+
+        except OEmbedError as e:
+            # Trap OEmbedErrors first so we can directly re-raise them.
+            logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+            raise
+
+        except Exception as e:
+            # Trap any exception and let the code follow as usual.
+            # FIXME: pass through 404s and other error messages nicely
+            logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+            raise OEmbedError() from e
+
     async def _download_url(self, url, user):
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
@@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        # If this URL can be accessed via oEmbed, use that instead.
+        url_to_download = url
+        oembed_url = self._get_oembed_url(url)
+        if oembed_url:
+            # The result might be a new URL to download, or it might be HTML content.
             try:
-                logger.debug("Trying to get preview for url '%s'", url)
-                length, headers, uri, code = await self.client.get_file(
-                    url,
-                    output_stream=f,
-                    max_size=self.max_spider_size,
-                    headers={"Accept-Language": self.url_preview_accept_language},
-                )
-            except SynapseError:
-                # Pass SynapseErrors through directly, so that the servlet
-                # handler will return a SynapseError to the client instead of
-                # blank data or a 500.
-                raise
-            except DNSLookupError:
-                # DNS lookup returned no results
-                # Note: This will also be the case if one of the resolved IP
-                # addresses is blacklisted
-                raise SynapseError(
-                    502,
-                    "DNS resolution failure during URL preview generation",
-                    Codes.UNKNOWN,
-                )
-            except Exception as e:
-                # FIXME: pass through 404s and other error messages nicely
-                logger.warning("Error downloading %s: %r", url, e)
+                oembed_result = await self._get_oembed_content(oembed_url, url)
+                if oembed_result.url:
+                    url_to_download = oembed_result.url
+                elif oembed_result.html:
+                    url_to_download = None
+            except OEmbedError:
+                # If an error occurs, try doing a normal preview.
+                pass
 
-                raise SynapseError(
-                    500,
-                    "Failed to download content: %s"
-                    % (traceback.format_exception_only(sys.exc_info()[0], e),),
-                    Codes.UNKNOWN,
-                )
-            await finish()
+        if url_to_download:
+            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+                try:
+                    logger.debug("Trying to get preview for url '%s'", url_to_download)
+                    length, headers, uri, code = await self.client.get_file(
+                        url_to_download,
+                        output_stream=f,
+                        max_size=self.max_spider_size,
+                        headers={"Accept-Language": self.url_preview_accept_language},
+                    )
+                except SynapseError:
+                    # Pass SynapseErrors through directly, so that the servlet
+                    # handler will return a SynapseError to the client instead of
+                    # blank data or a 500.
+                    raise
+                except DNSLookupError:
+                    # DNS lookup returned no results
+                    # Note: This will also be the case if one of the resolved IP
+                    # addresses is blacklisted
+                    raise SynapseError(
+                        502,
+                        "DNS resolution failure during URL preview generation",
+                        Codes.UNKNOWN,
+                    )
+                except Exception as e:
+                    # FIXME: pass through 404s and other error messages nicely
+                    logger.warning("Error downloading %s: %r", url_to_download, e)
+
+                    raise SynapseError(
+                        500,
+                        "Failed to download content: %s"
+                        % (traceback.format_exception_only(sys.exc_info()[0], e),),
+                        Codes.UNKNOWN,
+                    )
+                await finish()
+
+                if b"Content-Type" in headers:
+                    media_type = headers[b"Content-Type"][0].decode("ascii")
+                else:
+                    media_type = "application/octet-stream"
+
+                download_name = get_filename_from_headers(headers)
+
+                # FIXME: we should calculate a proper expiration based on the
+                # Cache-Control and Expire headers.  But for now, assume 1 hour.
+                expires = ONE_HOUR
+                etag = headers["ETag"][0] if "ETag" in headers else None
+        else:
+            html_bytes = oembed_result.html.encode("utf-8")  # type: ignore
+            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+                f.write(html_bytes)
+                await finish()
+
+            media_type = "text/html"
+            download_name = oembed_result.title
+            length = len(html_bytes)
+            # If a specific cache age was not given, assume 1 hour.
+            expires = oembed_result.cache_age or ONE_HOUR
+            uri = oembed_url
+            code = 200
+            etag = None
 
         try:
-            if b"Content-Type" in headers:
-                media_type = headers[b"Content-Type"][0].decode("ascii")
-            else:
-                media_type = "application/octet-stream"
             time_now_ms = self.clock.time_msec()
 
-            download_name = get_filename_from_headers(headers)
-
             await self.store.store_local_media(
                 media_id=file_id,
                 media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
+                time_now_ms=time_now_ms,
                 upload_name=download_name,
                 media_length=length,
                 user_id=user,
@@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource):
             "filename": fname,
             "uri": uri,
             "response_code": code,
-            # FIXME: we should calculate a proper expiration based on the
-            # Cache-Control and Expire headers.  But for now, assume 1 hour.
-            "expires": 60 * 60 * 1000,
-            "etag": headers["ETag"][0] if "ETag" in headers else None,
+            "expires": expires,
+            "etag": etag,
         }
 
     def _start_expire_url_cache_data(self):
@@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         # These may be cached for a bit on the client (i.e., they
         # may have a room open with a preview url thing open).
         # So we wait a couple of days before deleting, just in case.
-        expire_before = now - 2 * 24 * 60 * 60 * 1000
+        expire_before = now - 2 * 24 * ONE_HOUR
         media_ids = await self.store.get_url_cache_media_before(expire_before)
 
         removed_media = []
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..a33f56e806 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -16,62 +16,62 @@
 import logging
 import os
 import shutil
-
-from twisted.internet import defer
+from typing import Optional
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
 
+from ._base import FileInfo, Responder
 from .media_storage import FileResponder
 
 logger = logging.getLogger(__name__)
 
 
-class StorageProvider(object):
+class StorageProvider:
     """A storage provider is a service that can store uploaded media and
     retrieve them.
     """
 
-    def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo):
         """Store the file described by file_info. The actual contents can be
         retrieved by reading the file in file_info.upload_path.
 
         Args:
-            path (str): Relative path of file in local cache
-            file_info (FileInfo)
-
-        Returns:
-            Deferred
+            path: Relative path of file in local cache
+            file_info: The metadata of the file.
         """
-        pass
 
-    def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """Attempt to fetch the file described by file_info and stream it
         into writer.
 
         Args:
-            path (str): Relative path of file in local cache
-            file_info (FileInfo)
+            path: Relative path of file in local cache
+            file_info: The metadata of the file.
 
         Returns:
-            Deferred(Responder): Returns a Responder if the provider has the file,
-                otherwise returns None.
+            Returns a Responder if the provider has the file, otherwise returns None.
         """
-        pass
 
 
 class StorageProviderWrapper(StorageProvider):
     """Wraps a storage provider and provides various config options
 
     Args:
-        backend (StorageProvider)
-        store_local (bool): Whether to store new local files or not.
-        store_synchronous (bool): Whether to wait for file to be successfully
+        backend: The storage provider to wrap.
+        store_local: Whether to store new local files or not.
+        store_synchronous: Whether to wait for file to be successfully
             uploaded, or todo the upload in the background.
-        store_remote (bool): Whether remote media should be uploaded
+        store_remote: Whether remote media should be uploaded
     """
 
-    def __init__(self, backend, store_local, store_synchronous, store_remote):
+    def __init__(
+        self,
+        backend: StorageProvider,
+        store_local: bool,
+        store_synchronous: bool,
+        store_remote: bool,
+    ):
         self.backend = backend
         self.store_local = store_local
         self.store_synchronous = store_synchronous
@@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
     def __str__(self):
         return "StorageProviderWrapper[%s]" % (self.backend,)
 
-    def store_file(self, path, file_info):
+    async def store_file(self, path, file_info):
         if not file_info.server_name and not self.store_local:
-            return defer.succeed(None)
+            return None
 
         if file_info.server_name and not self.store_remote:
-            return defer.succeed(None)
+            return None
 
         if self.store_synchronous:
-            return self.backend.store_file(path, file_info)
+            return await self.backend.store_file(path, file_info)
         else:
             # TODO: Handle errors.
             def store():
@@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
                     logger.exception("Error storing file")
 
             run_in_background(store)
-            return defer.succeed(None)
+            return None
 
-    def fetch(self, path, file_info):
-        return self.backend.fetch(path, file_info)
+    async def fetch(self, path, file_info):
+        return await self.backend.fetch(path, file_info)
 
 
 class FileStorageProviderBackend(StorageProvider):
@@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
     def __str__(self):
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)
 
-    def store_file(self, path, file_info):
+    async def store_file(self, path, file_info):
         """See StorageProvider.store_file"""
 
         primary_fname = os.path.join(self.cache_directory, path)
@@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
         if not os.path.exists(dirname):
             os.makedirs(dirname)
 
-        return defer_to_thread(
+        return await defer_to_thread(
             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
         )
 
-    def fetch(self, path, file_info):
+    async def fetch(self, path, file_info):
         """See StorageProvider.fetch"""
 
         backup_fname = os.path.join(self.base_directory, path)