summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r-- synapse/rest/client/relations.py 75
-rw-r--r-- synapse/rest/client/room.py 3
-rw-r--r-- synapse/rest/client/user_directory.py 4
-rw-r--r-- synapse/rest/media/v1/preview_html.py 39
-rw-r--r-- synapse/rest/media/v1/preview_url_resource.py 23
5 files changed, 26 insertions, 118 deletions
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d9a6be43f7..c16078b187 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string(), allow_departed_users=True
-        )
-
-        # This gets the original event and checks that a) the event exists and
-        # b) the user is allowed to view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         limit = parse_integer(request, "limit", default=5)
         direction = parse_string(
             request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        pagination_chunk = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
-        )
-
-        now = self.clock.time_msec()
-        # Do not bundle aggregations when retrieving the original event because
-        # we want the content before relations are applied to it.
-        original_event = self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=None
-        )
-        # The relations returned for the requested event do include their
-        # bundled aggregations.
-        aggregations = await self.store.get_bundled_aggregations(
-            events, requester.user.to_string()
-        )
-        serialized_events = self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=aggregations
-        )
-
-        return_value = await pagination_chunk.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-        return_value["original_event"] = original_event
-
-        return 200, return_value
+        return 200, result
 
 
 class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id,
-            requester.user.to_string(),
-            allow_departed_users=True,
-        )
-
-        # This checks that a) the event exists and b) the user is allowed to
-        # view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
 
@@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        result = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in result.chunk]
-        )
-
-        now = self.clock.time_msec()
-        serialized_events = self._event_serializer.serialize_events(events, now)
-
-        return_value = await result.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-
-        return 200, return_value
+        return 200, result
 
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 8a06ab8c5f..47e152c8cc 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
         self._store = hs.get_datastores().main
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
+        self._relations_handler = hs.get_relations_handler()
         self.auth = hs.get_auth()
 
     async def on_GET(
@@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
 
         if event:
             # Ensure there are bundled aggregations available.
-            aggregations = await self._store.get_bundled_aggregations(
+            aggregations = await self._relations_handler.get_bundled_aggregations(
                 [event], requester.user.to_string()
             )
 
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index a47d9bd01d..116c982ce6 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonMapping
 
 from ._base import client_patterns
 
@@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.user_directory_handler = hs.get_user_directory_handler()
 
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
         """Searches for users in directory
 
         Returns:
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 872a9e72e8..4cc9c66fbe 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -16,7 +16,6 @@ import itertools
 import logging
 import re
 from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
-from urllib import parse as urlparse
 
 if TYPE_CHECKING:
     from lxml import etree
@@ -144,9 +143,7 @@ def decode_body(
     return etree.fromstring(body, parser)
 
 
-def parse_html_to_open_graph(
-    tree: "etree.Element", media_uri: str
-) -> Dict[str, Optional[str]]:
+def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
     """
     Parse the HTML document into an Open Graph response.
 
@@ -155,7 +152,6 @@ def parse_html_to_open_graph(
 
     Args:
         tree: The parsed HTML document.
-        media_url: The URI used to download the body.
 
     Returns:
         The Open Graph response as a dictionary.
@@ -209,7 +205,7 @@ def parse_html_to_open_graph(
             "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
         )
         if meta_image:
-            og["og:image"] = rebase_url(meta_image[0], media_uri)
+            og["og:image"] = meta_image[0]
         else:
             # TODO: consider inlined CSS styles as well as width & height attribs
             images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -320,37 +316,6 @@ def _iterate_over_text(
             )
 
 
-def rebase_url(url: str, base: str) -> str:
-    """
-    Resolves a potentially relative `url` against an absolute `base` URL.
-
-    For example:
-
-        >>> rebase_url("subpage", "https://example.com/foo/")
-        'https://example.com/foo/subpage'
-        >>> rebase_url("sibling", "https://example.com/foo")
-        'https://example.com/sibling'
-        >>> rebase_url("/bar", "https://example.com/foo/")
-        'https://example.com/bar'
-        >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
-        'https://alice.com/a'
-    """
-    base_parts = urlparse.urlparse(base)
-    # Convert the parsed URL to a list for (potential) modification.
-    url_parts = list(urlparse.urlparse(url))
-    # Add a scheme, if one does not exist.
-    if not url_parts[0]:
-        url_parts[0] = base_parts.scheme or "http"
-    # Fix up the hostname, if this is not a data URL.
-    if url_parts[0] != "data" and not url_parts[1]:
-        url_parts[1] = base_parts.netloc
-        # If the path does not start with a /, nest it under the base path's last
-        # directory.
-        if not url_parts[2].startswith("/"):
-            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
-    return urlparse.urlunparse(url_parts)
-
-
 def summarize_paragraphs(
     text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
 ) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 14ea88b240..d47af8ead6 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import shutil
 import sys
 import traceback
 from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
-from urllib import parse as urlparse
+from urllib.parse import urljoin, urlparse, urlsplit
 from urllib.request import urlopen
 
 import attr
@@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.rest.media.v1.preview_html import (
-    decode_body,
-    parse_html_to_open_graph,
-    rebase_url,
-)
+from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
 from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
@@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             ts = self.clock.time_msec()
 
         # XXX: we could move this into _do_preview if we wanted.
-        url_tuple = urlparse.urlsplit(url)
+        url_tuple = urlsplit(url)
         for entry in self.url_preview_url_blacklist:
             match = True
             for attrib in entry:
@@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
 
                 # Parse Open Graph information from the HTML in case the oEmbed
                 # response failed or is incomplete.
-                og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+                og_from_html = parse_html_to_open_graph(tree)
 
                 # Compile the Open Graph response by using the scraped
                 # information from the HTML and overlaying any information
@@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
         if "og:image" not in og or not og["og:image"]:
             return
 
+        # The image URL from the HTML might be relative to the previewed page;
+        # convert it to a URL which can be requested directly.
+        image_url = og["og:image"]
+        url_parts = urlparse(image_url)
+        if url_parts.scheme != "data":
+            image_url = urljoin(media_info.uri, image_url)
+
         # FIXME: it might be cleaner to use the same flow as the main /preview_url
         # request itself and benefit from the same caching etc.  But for now we
         # just rely on the caching on the master request to speed things up.
-        image_info = await self._handle_url(
-            rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
-        )
+        image_info = await self._handle_url(image_url, user, allow_data_urls=True)
 
         if _is_media(image_info.media_type):
             # TODO: make sure we don't choke on white-on-transparent images