summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/push_rule.py15
-rw-r--r--synapse/rest/client/v2_alpha/account.py66
-rw-r--r--synapse/rest/media/v1/filepath.py19
-rw-r--r--synapse/rest/media/v1/media_repository.py69
-rw-r--r--synapse/rest/media/v1/media_storage.py28
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py5
-rw-r--r--synapse/rest/media/v1/thumbnailer.py14
7 files changed, 192 insertions, 24 deletions
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e781a3bcf4..ddf8ed5e9c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet):
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
 
     async def set_rule_attr(self, user_id, spec, val):
+        if spec["attr"] not in ("enabled", "actions"):
+            # for the sake of potential future expansion, shouldn't report
+            # 404 in the case of an unknown request so check it corresponds to
+            # a known attribute first.
+            raise UnrecognizedRequestError()
+
+        namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+        rule_id = spec["rule_id"]
+        is_default_rule = rule_id.startswith(".")
+        if is_default_rule:
+            if namespaced_rule_id not in BASE_RULE_IDS:
+                raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
         if spec["attr"] == "enabled":
             if isinstance(val, dict) and "enabled" in val:
                 val = val["enabled"]
@@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet):
                 # This should *actually* take a dict, but many clients pass
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
-            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
             return await self.store.set_push_rule_enabled(
-                user_id, namespaced_rule_id, val
+                user_id, namespaced_rule_id, val, is_default_rule
             )
         elif spec["attr"] == "actions":
             actions = val.get("actions")
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3481477731..455051ac46 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -17,6 +17,11 @@
 import logging
 import random
 from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
@@ -98,6 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -446,6 +454,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         existing_user_id = await self.store.get_user_id_by_threepid("email", email)
 
         if existing_user_id is not None:
@@ -517,6 +528,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # Raise if the provided next_link value isn't valid
+        assert_valid_next_link(self.hs, next_link)
+
         existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
 
         if existing_user_id is not None:
@@ -603,15 +617,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
 
             # Perform a 302 redirect if next_link is set
             if next_link:
-                if next_link.startswith("file:///"):
-                    logger.warning(
-                        "Not redirecting to next_link as it is a local file: address"
-                    )
-                else:
-                    request.setResponseCode(302)
-                    request.setHeader("Location", next_link)
-                    finish_request(request)
-                    return None
+                request.setResponseCode(302)
+                request.setHeader("Location", next_link)
+                finish_request(request)
+                return None
 
             # Otherwise show the success template
             html = self.config.email_add_threepid_template_success_html_content
@@ -875,6 +884,45 @@ class ThreepidDeleteRestServlet(RestServlet):
         return 200, {"id_server_unbind_result": id_server_unbind_result}
 
 
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+    """
+    Raises a SynapseError if a given next_link value is invalid
+
+    next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
+    option is either empty or contains a domain that matches the one in the given next_link
+
+    Args:
+        hs: The homeserver object
+        next_link: The next_link value given by the client
+
+    Raises:
+        SynapseError: If the next_link is invalid
+    """
+    valid = True
+
+    # Parse the contents of the URL
+    next_link_parsed = urlparse(next_link)
+
+    # Scheme must not point to the local drive
+    if next_link_parsed.scheme == "file":
+        valid = False
+
+    # If the domain whitelist is set, the domain must be in it
+    if (
+        valid
+        and hs.config.next_link_domain_whitelist is not None
+        and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+    ):
+        valid = False
+
+    if not valid:
+        raise SynapseError(
+            400,
+            "'next_link' domain not included in whitelist, or not http(s)",
+            errcode=Codes.INVALID_PARAM,
+        )
+
+
 class WhoamiRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/whoami$")
 
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d2826374a7..7447eeaebe 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -80,7 +80,7 @@ class MediaFilePaths:
         self, server_name, file_id, width, height, content_type, method
     ):
         top_level_type, sub_type = content_type.split("/")
-        file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+        file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
             "remote_thumbnail",
             server_name,
@@ -92,6 +92,23 @@ class MediaFilePaths:
 
     remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
 
+    # Legacy path that was used to store thumbnails previously.
+    # Should be removed after some time, when most of the thumbnails are stored
+    # using the new path.
+    def remote_media_thumbnail_rel_legacy(
+        self, server_name, file_id, width, height, content_type
+    ):
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+        return os.path.join(
+            "remote_thumbnail",
+            server_name,
+            file_id[0:2],
+            file_id[2:4],
+            file_id[4:],
+            file_name,
+        )
+
     def remote_media_thumbnail_dir(self, server_name, file_id):
         return os.path.join(
             self.base_path,
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9a1b7779f7..69f353d46f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -53,7 +53,7 @@ from .media_storage import MediaStorage
 from .preview_url_resource import PreviewUrlResource
 from .storage_provider import StorageProviderWrapper
 from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer
+from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource
 
 logger = logging.getLogger(__name__)
@@ -460,13 +460,30 @@ class MediaRepository:
         return t_byte_source
 
     async def generate_local_exact_thumbnail(
-        self, media_id, t_width, t_height, t_method, t_type, url_cache
-    ):
+        self,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+        url_cache: str,
+    ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(None, media_id, url_cache=url_cache)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+                media_id,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
         t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
@@ -506,14 +523,36 @@ class MediaRepository:
 
             return output_path
 
+        # Could not generate thumbnail.
+        return None
+
     async def generate_remote_exact_thumbnail(
-        self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[str]:
         input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(server_name, file_id, url_cache=False)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+                media_id,
+                server_name,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
         t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
@@ -559,6 +598,9 @@ class MediaRepository:
 
             return output_path
 
+        # Could not generate thumbnail.
+        return None
+
     async def _generate_thumbnails(
         self,
         server_name: Optional[str],
@@ -590,7 +632,18 @@ class MediaRepository:
             FileInfo(server_name, file_id, url_cache=url_cache)
         )
 
-        thumbnailer = Thumbnailer(input_path)
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s",
+                media_id,
+                server_name,
+                media_type,
+                e,
+            )
+            return None
+
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3a352b5631..5681677fc9 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -147,6 +147,20 @@ class MediaStorage:
         if os.path.exists(local_path):
             return FileResponder(open(local_path, "rb"))
 
+        # Fallback for paths without method names
+        # Should be removed in the future
+        if file_info.thumbnail and file_info.server_name:
+            legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+                server_name=file_info.server_name,
+                file_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+            )
+            legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+            if os.path.exists(legacy_local_path):
+                return FileResponder(open(legacy_local_path, "rb"))
+
         for provider in self.storage_providers:
             res = await provider.fetch(path, file_info)  # type: Any
             if res:
@@ -170,6 +184,20 @@ class MediaStorage:
         if os.path.exists(local_path):
             return local_path
 
+        # Fallback for paths without method names
+        # Should be removed in the future
+        if file_info.thumbnail and file_info.server_name:
+            legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+                server_name=file_info.server_name,
+                file_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+            )
+            legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+            if os.path.exists(legacy_local_path):
+                return legacy_local_path
+
         dirname = os.path.dirname(local_path)
         if not os.path.exists(dirname):
             os.makedirs(dirname)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@
 
 import logging
 
+from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
 
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
             await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
-            respond_404(request)
+            raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _select_or_generate_remote_thumbnail(
         self,
@@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource):
             await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
-            respond_404(request)
+            raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _respond_remote_thumbnail(
         self, request, server_name, media_id, width, height, method, m_type
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index d681bf7bf0..457ad6031c 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@
 import logging
 from io import BytesIO
 
-from PIL import Image as Image
+from PIL import Image
 
 logger = logging.getLogger(__name__)
 
@@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = {
 }
 
 
+class ThumbnailError(Exception):
+    """An error occurred generating a thumbnail."""
+
+
 class Thumbnailer:
 
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
     def __init__(self, input_path):
-        self.image = Image.open(input_path)
+        try:
+            self.image = Image.open(input_path)
+        except OSError as e:
+            # If an error occurs opening the image, a thumbnail won't be able to
+            # be generated.
+            raise ThumbnailError from e
+
         self.width, self.height = self.image.size
         self.transpose_method = None
         try: