diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index b8c95d045a..a8364d9793 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
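+ # Optional "purge" flag: if true (the default, preserving the old
+ # behaviour), the room is also purged from the database after shutdown.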
+ purge = content.get("purge", True)
+ if not isinstance(purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet):
)
# Purge room
- await self.pagination_handler.purge_room(room_id)
+ if purge:
+ await self.pagination_handler.purge_room(room_id)
return (200, ret)
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index b21538766d..f016b4f1bd 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -17,8 +17,7 @@
"""
import logging
import re
-
-from twisted.internet import defer
+from typing import Iterable, Pattern
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -27,15 +26,23 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
-def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
+def client_patterns(
+ path_regex: str,
+ releases: Iterable[int] = (0,),
+ unstable: bool = True,
+ v1: bool = False,
+) -> Iterable[Pattern]:
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
- path_regex (str): The regex string to match. This should NOT have a ^
+ path_regex: The regex string to match. This should NOT have a ^
as this will be prefixed.
+ releases: An iterable of releases to include this endpoint under.
+ unstable: If true, include this endpoint under the "unstable" prefix.
+ v1: If true, include this endpoint under the "api/v1" prefix.
Returns:
- SRE_Pattern
+ An iterable of patterns.
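+
+ For example, client_patterns("/account/whoami$") yields patterns
+ matching "/_matrix/client/r0/account/whoami" and, since unstable
+ defaults to True, "/_matrix/client/unstable/account/whoami".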
"""
patterns = []
@@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
- Takes a on_POST method which returns a deferred (errcode, body) response
+ Takes an on_POST method which returns an Awaitable (errcode, body) response
- and adds exception handling to turn a InteractiveAuthIncompleteError into
+ and adds exception handling to turn an InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
# ...
- yield self.auth_handler.check_auth
- """
+ await self.auth_handler.check_auth
+ """
- def wrapped(*args, **kwargs):
- res = defer.ensureDeferred(orig(*args, **kwargs))
- res.addErrback(_catch_incomplete_interactive_auth)
- return res
+ async def wrapped(*args, **kwargs):
+ try:
+ return await orig(*args, **kwargs)
+ except InteractiveAuthIncompleteError as e:
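+ # e.result carries the remaining UI-auth data (flows, params,
+ # session) that the client needs to continue, returned with the 401.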
+ return 401, e.result
return wrapped
-
-
-def _catch_incomplete_interactive_auth(f):
- """helper for interactive_auth_handler
-
- Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
-
- Args:
- f (failure.Failure):
- """
- f.trap(InteractiveAuthIncompleteError)
- return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..3f5bf75e59 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
+ result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 595849f9d5..20ddb9550b 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,8 +17,9 @@
import logging
import os
import urllib
+from typing import Awaitable
-from twisted.internet import defer
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -77,8 +78,9 @@ def respond_404(request):
)
-@defer.inlineCallbacks
-def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
+async def respond_with_file(
+ request, media_type, file_path, file_size=None, upload_name=None
+):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@@ -89,7 +91,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+ await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
@@ -198,8 +200,9 @@ def _can_encode_filename_as_token(x):
return True
-@defer.inlineCallbacks
-def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+async def respond_with_responder(
+ request, responder, media_type, file_size, upload_name=None
+):
"""Responds to the request with given responder. If responder is None then
returns 404.
@@ -218,7 +221,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
- yield responder.write_to_consumer(request)
+ await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us
@@ -239,14 +242,14 @@ class Responder(object):
held can be cleaned up.
"""
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
- consumer (IConsumer)
+ consumer: The consumer to stream into.
Returns:
- Deferred: Resolves once the response has finished being written
+ Resolves once the response has finished being written
"""
pass
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 45628c07b4..6fb4039e98 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,10 +18,11 @@ import errno
import logging
import os
import shutil
-from typing import Dict, Tuple
+from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
+ Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
- self, media_type, upload_name, content, content_length, auth_user
- ):
+ self,
+ media_type: str,
+ upload_name: str,
+ content: IO,
+ content_length: int,
+ auth_user: str,
+ ) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
- media_type(str): The content type of the file
- upload_name(str): The name of the file
+ media_type: The content type of the file
+ upload_name: The name of the file
content: A file like object that is the content to store
- content_length(int): The length of the content
- auth_user(str): The user_id of the uploader
+ content_length: The length of the content
+ auth_user: The user_id of the uploader
Returns:
- Deferred[str]: The mxc url of the stored content
+ The mxc URL of the stored content
"""
media_id = random_string(24)
@@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
- async def get_local_media(self, request, media_id, name):
+ async def get_local_media(
+ self, request: Request, media_id: str, name: Optional[str]
+ ) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
- request(twisted.web.http.Request)
- media_id (str): The media ID of the content. (This is the same as
+ request: The incoming request.
+ media_id: The media ID of the content. (This is the same as
the file_id for local content.)
- name (str|None): Optional name that, if specified, will be used as
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
- async def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(
+ self, request: Request, server_name: str, media_id: str, name: Optional[str]
+ ) -> None:
"""Respond to requests for remote media.
Args:
- request(twisted.web.http.Request)
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
- name (str|None): Optional name that, if specified, will be used as
+ request: The incoming request.
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@@ -245,17 +253,16 @@ class MediaRepository(object):
else:
respond_404(request)
- async def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
Returns:
- Deferred[dict]: The media_info of the file
+ The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info
- async def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(
+ self, server_name: str, media_id: str
+ ) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -288,7 +297,7 @@ class MediaRepository(object):
remote server).
Returns:
- Deferred[(Responder, media_info)]
+ A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(
+ self, server_name: str, media_id: str, file_id: str
+ ) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
- server_name (str): Originating server
- media_id (str): The media ID of the content (as defined by the
+ server_name: Originating server
+ media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
- file_id (str): Local file ID
+ file_id: Local file ID
Returns:
- Deferred[MediaInfo]
+ The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
@@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
- self, server_name, media_id, file_id, media_type, url_cache=False
- ):
+ self,
+ server_name: Optional[str],
+ media_id: str,
+ file_id: str,
+ media_type: str,
+ url_cache: bool = False,
+ ) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
- server_name (str|None): The server name if remote media, else None if local
- media_id (str): The media ID of the content. (This is the same as
+ server_name: The server name if remote media, else None if local
+ media_id: The media ID of the content. (This is the same as
the file_id for local content)
- file_id (str): Local file ID
- media_type (str): The content type of the file
- url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+ file_id: Local file ID
+ media_type: The content type of the file
+ url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
- Deferred[dict]: Dict with "width" and "height" keys of original image
+ Dict with "width" and "height" keys of original image or None if the
+ media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
- return
+ return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
@@ -584,7 +601,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 79cb0dddbe..858b6d3005 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -12,19 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import contextlib
+import inspect
import logging
import os
import shutil
+from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
-from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
-from ._base import Responder
+from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+ from .storage_provider import StorageProvider
logger = logging.getLogger(__name__)
@@ -33,49 +39,53 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
- hs (synapse.server.Homeserver)
- local_media_directory (str): Base path where we store media on disk
- filepaths (MediaFilePaths)
- storage_providers ([StorageProvider]): List of StorageProvider that are
- used to fetch and store files.
+ hs
+ local_media_directory: Base path where we store media on disk
+ filepaths
+ storage_providers: List of StorageProviders used to fetch and store files.
"""
- def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ local_media_directory: str,
+ filepaths: MediaFilePaths,
+ storage_providers: Sequence["StorageProvider"],
+ ):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
- @defer.inlineCallbacks
- def store_file(self, source, file_info):
+ async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Returns:
- Deferred[str]: the file path written to in the primary media store
+ the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- yield defer_to_thread(
+ await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
- yield finish_cb()
+ await finish_cb()
return fname
@contextlib.contextmanager
- def store_into_file(self, file_info):
+ def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
- on disk, and finish_cb is a function that returns a Deferred.
+ on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
@@ -85,13 +95,13 @@ class MediaStorage(object):
error.
Args:
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
- yield finish_cb()
+ await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -103,10 +113,13 @@ class MediaStorage(object):
finished_called = [False]
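+ # NB: a one-element list so that the nested finish() below can set
+ # the flag without needing "nonlocal".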
- @defer.inlineCallbacks
- def finish():
+ async def finish():
for provider in self.storage_providers:
- yield provider.store_file(path, file_info)
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = provider.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ await result
finished_called[0] = True
@@ -123,17 +136,15 @@ class MediaStorage(object):
- if not finished_called:
+ # finished_called is a list (always truthy), so check the flag inside it.
+ if not finished_called[0]:
raise Exception("Finished callback not called")
- @defer.inlineCallbacks
- def fetch_media(self, file_info):
+ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[Responder|None]: Returns a Responder if the file was found,
- otherwise None.
+ Returns a Responder if the file was found, otherwise None.
"""
path = self._file_info_to_path(file_info)
@@ -142,23 +153,26 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = provider.fetch(path, file_info) # type: Any
+ # Fetch is supposed to return an Awaitable[Responder], but guard
+ # against improper implementations.
+ if inspect.isawaitable(res):
+ res = await res
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
return None
- @defer.inlineCallbacks
- def ensure_media_is_in_local_cache(self, file_info):
+ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[str]: Full path to local file
+ Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
@@ -170,29 +184,27 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = provider.fetch(path, file_info) # type: Any
+ # Fetch is supposed to return an Awaitable[Responder], but guard
+ # against improper implementations.
+ if inspect.isawaitable(res):
+ res = await res
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
)
- yield res.write_to_consumer(consumer)
- yield consumer.wait()
+ await res.write_to_consumer(consumer)
+ await consumer.wait()
return local_path
raise Exception("file could not be found")
- def _file_info_to_path(self, file_info):
+ def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
-
- Args:
- file_info (FileInfo)
-
- Returns:
- str
"""
if file_info.url_cache:
if file_info.thumbnail:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e52c86c798..e12f65a206 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -26,6 +26,7 @@ import traceback
from typing import Dict, Optional
from urllib import parse as urlparse
+import attr
from canonicaljson import json
from twisted.internet import defer
@@ -56,6 +57,65 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
+ONE_HOUR = 60 * 60 * 1000
+
+# A map of oEmbed API endpoints to lists of URL globs that they handle.
+_oembed_globs = {
+ # Twitter.
+ "https://publish.twitter.com/oembed": [
+ "https://twitter.com/*/status/*",
+ "https://*.twitter.com/*/status/*",
+ "https://twitter.com/*/moments/*",
+ "https://*.twitter.com/*/moments/*",
+ # Include the HTTP versions too.
+ "http://twitter.com/*/status/*",
+ "http://*.twitter.com/*/status/*",
+ "http://twitter.com/*/moments/*",
+ "http://*.twitter.com/*/moments/*",
+ ],
+}
+# Convert the globs to regular expressions.
+_oembed_patterns = {}
+for endpoint, globs in _oembed_globs.items():
+ for glob in globs:
+ # Convert the glob into a sane regular expression to match against. The
+ # rules followed will be slightly different for the domain portion vs.
+ # the rest.
+ #
+ # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
+ # 2. The domain can have globs, but we limit it to characters that can
+ # reasonably be a domain part.
+ # TODO: This does not attempt to handle Unicode domain names.
+ # 3. Other parts allow a glob to be any one, or more, characters.
+ results = urlparse.urlparse(glob)
+
+ # Ensure the scheme does not have wildcards (and is a sane scheme).
+ if results.scheme not in {"http", "https"}:
+ raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
+
+ pattern = urlparse.urlunparse(
+ [
+ results.scheme,
+ re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+ ]
+ + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+ )
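+ # e.g. "https://*.twitter.com/*/status/*" becomes the pattern
+ # https://[a-zA-Z0-9_-]+\.twitter\.com/.+/status/.+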
+ _oembed_patterns[re.compile(pattern)] = endpoint
+
+
+@attr.s
+class OEmbedResult:
+ # Either HTML content or URL must be provided.
+ html = attr.ib(type=Optional[str])
+ url = attr.ib(type=Optional[str])
+ title = attr.ib(type=Optional[str])
+ # Number of seconds to cache the content, or None if not given.
+ cache_age = attr.ib(type=Optional[int])
+
+
+class OEmbedError(Exception):
+ """An error occurred processing the oEmbed object."""
+
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
@@ -99,7 +159,7 @@ class PreviewUrlResource(DirectServeJsonResource):
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
- expiry_ms=60 * 60 * 1000,
+ expiry_ms=ONE_HOUR,
)
if self._worker_run_media_background_jobs:
@@ -171,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
- async def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
- url (str):
- user (str):
- ts (int):
+ url: The URL to preview.
+ user: The user requesting the preview.
+ ts: The timestamp requested for the preview.
Returns:
- Deferred[bytes]: json-encoded og data
+ json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
@@ -310,6 +370,87 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
+ def _get_oembed_url(self, url: str) -> Optional[str]:
+ """
+ Check whether the URL should be downloaded as oEmbed content instead.
+
+ Args:
+ url: The URL to check.
+
+ Returns:
+ A URL to use instead or None if the original URL should be used.
+ """
+ for url_pattern, endpoint in _oembed_patterns.items():
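+ # Note that fullmatch() requires the entire URL to match the
+ # pattern, not just a prefix.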
+ if url_pattern.fullmatch(url):
+ return endpoint
+
+ # No match.
+ return None
+
+ async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ """
+ Request content from an oEmbed endpoint.
+
+ Args:
+ endpoint: The oEmbed API endpoint.
+ url: The URL to pass to the API.
+
+ Returns:
+ An object representing the metadata returned.
+
+ Raises:
+ OEmbedError if fetching or parsing of the oEmbed information fails.
+ """
+ try:
+ logger.debug("Trying to get oEmbed content for url '%s'", url)
+ result = await self.client.get_json(
+ endpoint,
+ # TODO Specify max height / width.
+ # Note that only the JSON format is supported.
+ args={"url": url},
+ )
+
+ # Ensure the reported oEmbed version is 1.0.
+ if result.get("version") != "1.0":
+ raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+ oembed_type = result.get("type")
+
+ # Ensure the cache age is None or an int.
+ cache_age = result.get("cache_age")
+ if cache_age:
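+ # Per the oEmbed spec this is a number of seconds, though it
+ # may be returned as a string, so coerce it to an int.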
+ cache_age = int(cache_age)
+
+ oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+ # HTML content.
+ if oembed_type == "rich":
+ oembed_result.html = result.get("html")
+ return oembed_result
+
+ if oembed_type == "photo":
+ oembed_result.url = result.get("url")
+ return oembed_result
+
+ # TODO Handle link and video types.
+
+ if "thumbnail_url" in result:
+ oembed_result.url = result.get("thumbnail_url")
+ return oembed_result
+
+ raise OEmbedError("Incompatible oEmbed information.")
+
+ except OEmbedError as e:
+ # Trap OEmbedErrors first so we can directly re-raise them.
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+ raise
+
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+ raise OEmbedError() from e
+
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@@ -319,54 +460,90 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ # If this URL can be accessed via oEmbed, use that instead.
+ url_to_download = url
+ oembed_url = self._get_oembed_url(url)
+ if oembed_url:
+ # The result might be a new URL to download, or it might be HTML content.
try:
- logger.debug("Trying to get preview for url '%s'", url)
- length, headers, uri, code = await self.client.get_file(
- url,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url, e)
+ oembed_result = await self._get_oembed_content(oembed_url, url)
+ if oembed_result.url:
+ url_to_download = oembed_result.url
+ elif oembed_result.html:
+ url_to_download = None
+ except OEmbedError:
+ # If an error occurs, try doing a normal preview.
+ pass
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
+ if url_to_download:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ logger.debug("Trying to get preview for url '%s'", url_to_download)
+ length, headers, uri, code = await self.client.get_file(
+ url_to_download,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url_to_download, e)
+
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+ await finish()
+
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
+
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers["ETag"][0] if "ETag" in headers else None
+ else:
+ html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ f.write(html_bytes)
+ await finish()
+
+ media_type = "text/html"
+ download_name = oembed_result.title
+ length = len(html_bytes)
+ # cache_age is in seconds, while expires is in milliseconds; if
+ # no cache age was given, assume 1 hour.
+ expires = oembed_result.cache_age * 1000 if oembed_result.cache_age else ONE_HOUR
+ uri = oembed_url
+ code = 200
+ etag = None
try:
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
- else:
- media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
- download_name = get_filename_from_headers(headers)
-
await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
- time_now_ms=self.clock.time_msec(),
+ time_now_ms=time_now_ms,
upload_name=download_name,
media_length=length,
user_id=user,
@@ -389,10 +566,8 @@ class PreviewUrlResource(DirectServeJsonResource):
"filename": fname,
"uri": uri,
"response_code": code,
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- "expires": 60 * 60 * 1000,
- "etag": headers["ETag"][0] if "ETag" in headers else None,
+ "expires": expires,
+ "etag": etag,
}
def _start_expire_url_cache_data(self):
@@ -449,7 +624,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * 24 * 60 * 60 * 1000
+ expire_before = now - 2 * 24 * ONE_HOUR
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..a33f56e806 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -16,62 +16,62 @@
import logging
import os
import shutil
-
-from twisted.internet import defer
+from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
-class StorageProvider(object):
+class StorageProvider:
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
-
- Returns:
- Deferred
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
"""
- pass
- def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
Returns:
- Deferred(Responder): Returns a Responder if the provider has the file,
- otherwise returns None.
+ Returns a Responder if the provider has the file, otherwise returns None.
"""
- pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
- backend (StorageProvider)
- store_local (bool): Whether to store new local files or not.
- store_synchronous (bool): Whether to wait for file to be successfully
+ backend: The storage provider to wrap.
+ store_local: Whether to store new local files or not.
+ store_synchronous: Whether to wait for file to be successfully
- uploaded, or todo the upload in the background.
+ uploaded, or to do the upload in the background.
- store_remote (bool): Whether remote media should be uploaded
+ store_remote: Whether remote media should be uploaded
"""
- def __init__(self, backend, store_local, store_synchronous, store_remote):
+ def __init__(
+ self,
+ backend: StorageProvider,
+ store_local: bool,
+ store_synchronous: bool,
+ store_remote: bool,
+ ):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
@@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
- return defer.succeed(None)
+ return None
if file_info.server_name and not self.store_remote:
- return defer.succeed(None)
+ return None
if self.store_synchronous:
- return self.backend.store_file(path, file_info)
+ return await self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
def store():
@@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
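+ # Fire-and-forget: the upload continues in the background and
+ # this call does not wait for it to complete.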
- return defer.succeed(None)
+ return None
- def fetch(self, path, file_info):
- return self.backend.fetch(path, file_info)
+ async def fetch(self, path, file_info):
+ return await self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider):
@@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return defer_to_thread(
+ return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- def fetch(self, path, file_info):
+ async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
|