diff --git a/changelog.d/7947.misc b/changelog.d/7947.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/7947.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 9a847130c0..20ddb9550b 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,9 @@
import logging
import os
import urllib
+from typing import Awaitable
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -240,14 +242,14 @@ class Responder(object):
held can be cleaned up.
"""
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
- consumer (IConsumer)
+ consumer: The consumer to stream into.
Returns:
- Deferred: Resolves once the response has finished being written
+ Resolves once the response has finished being written
"""
pass
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 45628c07b4..6fb4039e98 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,10 +18,11 @@ import errno
import logging
import os
import shutil
-from typing import Dict, Tuple
+from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
+ Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
- self, media_type, upload_name, content, content_length, auth_user
- ):
+ self,
+ media_type: str,
+ upload_name: str,
+ content: IO,
+ content_length: int,
+ auth_user: str,
+ ) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
- media_type(str): The content type of the file
- upload_name(str): The name of the file
+ media_type: The content type of the file
+ upload_name: The name of the file
content: A file like object that is the content to store
- content_length(int): The length of the content
- auth_user(str): The user_id of the uploader
+ content_length: The length of the content
+ auth_user: The user_id of the uploader
Returns:
- Deferred[str]: The mxc url of the stored content
+ The mxc url of the stored content
"""
media_id = random_string(24)
@@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
- async def get_local_media(self, request, media_id, name):
+ async def get_local_media(
+ self, request: Request, media_id: str, name: Optional[str]
+ ) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
- request(twisted.web.http.Request)
- media_id (str): The media ID of the content. (This is the same as
+ request: The incoming request.
+ media_id: The media ID of the content. (This is the same as
the file_id for local content.)
- name (str|None): Optional name that, if specified, will be used as
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
- async def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(
+ self, request: Request, server_name: str, media_id: str, name: Optional[str]
+ ) -> None:
"""Respond to requests for remote media.
Args:
- request(twisted.web.http.Request)
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
- name (str|None): Optional name that, if specified, will be used as
+ request: The incoming request.
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@@ -245,17 +253,16 @@ class MediaRepository(object):
else:
respond_404(request)
- async def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
Returns:
- Deferred[dict]: The media_info of the file
+ The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info
- async def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(
+ self, server_name: str, media_id: str
+ ) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -288,7 +297,7 @@ class MediaRepository(object):
remote server).
Returns:
- Deferred[(Responder, media_info)]
+ A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(
+ self, server_name: str, media_id: str, file_id: str
+ ) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
- server_name (str): Originating server
- media_id (str): The media ID of the content (as defined by the
+ server_name: Originating server
+ media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
- file_id (str): Local file ID
+ file_id: Local file ID
Returns:
- Deferred[MediaInfo]
+ The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
@@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
- self, server_name, media_id, file_id, media_type, url_cache=False
- ):
+ self,
+ server_name: Optional[str],
+ media_id: str,
+ file_id: str,
+ media_type: str,
+ url_cache: bool = False,
+ ) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
- server_name (str|None): The server name if remote media, else None if local
- media_id (str): The media ID of the content. (This is the same as
+ server_name: The server name if remote media, else None if local
+ media_id: The media ID of the content. (This is the same as
the file_id for local content)
- file_id (str): Local file ID
- media_type (str): The content type of the file
- url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+ file_id: Local file ID
+ media_type: The content type of the file
+ url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
- Deferred[dict]: Dict with "width" and "height" keys of original image
+ Dict with "width" and "height" keys of original image or None if the
+ media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
- return
+ return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
@@ -584,7 +601,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 66bc1c3360..858b6d3005 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import contextlib
import inspect
import logging
import os
import shutil
-from typing import Optional
+from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.protocols.basic import FileSender
@@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+ from .storage_provider import StorageProvider
logger = logging.getLogger(__name__)
@@ -34,20 +39,25 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
- hs (synapse.server.Homeserver)
- local_media_directory (str): Base path where we store media on disk
- filepaths (MediaFilePaths)
- storage_providers ([StorageProvider]): List of StorageProvider that are
- used to fetch and store files.
+ hs
+ local_media_directory: Base path where we store media on disk
+ filepaths
+ storage_providers: List of StorageProvider that are used to fetch and store files.
"""
- def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ local_media_directory: str,
+ filepaths: MediaFilePaths,
+ storage_providers: Sequence["StorageProvider"],
+ ):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
- async def store_file(self, source, file_info: FileInfo) -> str:
+ async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
@@ -69,7 +79,7 @@ class MediaStorage(object):
return fname
@contextlib.contextmanager
- def store_into_file(self, file_info):
+ def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -85,7 +95,7 @@ class MediaStorage(object):
error.
Args:
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Example:
@@ -143,9 +153,9 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
- res = provider.fetch(path, file_info)
- # Fetch is supposed to return an Awaitable, but guard against
- # improper implementations.
+ res = provider.fetch(path, file_info) # type: Any
+ # Fetch is supposed to return an Awaitable[Responder], but guard
+ # against improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
@@ -174,9 +184,9 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = provider.fetch(path, file_info)
- # Fetch is supposed to return an Awaitable, but guard against
- # improper implementations.
+ res = provider.fetch(path, file_info) # type: Any
+ # Fetch is supposed to return an Awaitable[Responder], but guard
+ # against improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
@@ -190,17 +200,11 @@ class MediaStorage(object):
raise Exception("file could not be found")
- def _file_info_to_path(self, file_info):
+ def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
-
- Args:
- file_info (FileInfo)
-
- Returns:
- str
"""
if file_info.url_cache:
if file_info.thumbnail:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 13d1a6d2ed..e12f65a206 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
- async def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
- url (str):
- user (str):
- ts (int):
+ url: The URL to preview.
+ user: The user requesting the preview.
+ ts: The timestamp requested for the preview.
Returns:
- Deferred[bytes]: json-encoded og data
+ json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..a33f56e806 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -16,62 +16,62 @@
import logging
import os
import shutil
-
-from twisted.internet import defer
+from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
-class StorageProvider(object):
+class StorageProvider:
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
-
- Returns:
- Deferred
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
"""
- pass
- def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
Returns:
- Deferred(Responder): Returns a Responder if the provider has the file,
- otherwise returns None.
+ Returns a Responder if the provider has the file, otherwise returns None.
"""
- pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
- backend (StorageProvider)
- store_local (bool): Whether to store new local files or not.
- store_synchronous (bool): Whether to wait for file to be successfully
+ backend: The storage provider to wrap.
+ store_local: Whether to store new local files or not.
+ store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background.
- store_remote (bool): Whether remote media should be uploaded
+ store_remote: Whether remote media should be uploaded
"""
- def __init__(self, backend, store_local, store_synchronous, store_remote):
+ def __init__(
+ self,
+ backend: StorageProvider,
+ store_local: bool,
+ store_synchronous: bool,
+ store_remote: bool,
+ ):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
@@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
- return defer.succeed(None)
+ return None
if file_info.server_name and not self.store_remote:
- return defer.succeed(None)
+ return None
if self.store_synchronous:
- return self.backend.store_file(path, file_info)
+ return await self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
def store():
@@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
- return defer.succeed(None)
+ return None
- def fetch(self, path, file_info):
- return self.backend.fetch(path, file_info)
+ async def fetch(self, path, file_info):
+ return await self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider):
@@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return defer_to_thread(
+ return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- def fetch(self, path, file_info):
+ async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
|