summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4157.bugfix1
-rw-r--r--synapse/http/server.py8
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py22
-rw-r--r--tests/rest/media/v1/test_url_preview.py164
-rw-r--r--tests/server.py2
5 files changed, 187 insertions, 10 deletions
diff --git a/changelog.d/4157.bugfix b/changelog.d/4157.bugfix
new file mode 100644
index 0000000000..265514c3af
--- /dev/null
+++ b/changelog.d/4157.bugfix
@@ -0,0 +1 @@
+Loading URL previews from the DB cache on Postgres will no longer cause Unicode type errors when responding to the request, and URL previews will no longer fail if the remote server returns a Content-Type header with the chartype in quotes.
\ No newline at end of file
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b4b25cab19..6a427d96a6 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -468,13 +468,13 @@ def set_cors_headers(request):
     Args:
         request (twisted.web.http.Request): The http request to add CORs to.
     """
-    request.setHeader("Access-Control-Allow-Origin", "*")
+    request.setHeader(b"Access-Control-Allow-Origin", b"*")
     request.setHeader(
-        "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
+        b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
     )
     request.setHeader(
-        "Access-Control-Allow-Headers",
-        "Origin, X-Requested-With, Content-Type, Accept, Authorization"
+        b"Access-Control-Allow-Headers",
+        b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
     )
 
 
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 1a7bfd6b56..91d1dafe64 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import cgi
 import datetime
 import errno
@@ -24,6 +25,7 @@ import shutil
 import sys
 import traceback
 
+import six
 from six import string_types
 from six.moves import urllib_parse as urlparse
 
@@ -98,7 +100,7 @@ class PreviewUrlResource(Resource):
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = yield self.auth.get_user_by_req(request)
         url = parse_string(request, "url")
-        if "ts" in request.args:
+        if b"ts" in request.args:
             ts = parse_integer(request, "ts")
         else:
             ts = self.clock.time_msec()
@@ -180,7 +182,12 @@ class PreviewUrlResource(Resource):
             cache_result["expires_ts"] > ts and
             cache_result["response_code"] / 100 == 2
         ):
-            defer.returnValue(cache_result["og"])
+            # It may be stored as text in the database, not as bytes (such as
+            # PostgreSQL). If so, encode it back before handing it on.
+            og = cache_result["og"]
+            if isinstance(og, six.text_type):
+                og = og.encode('utf8')
+            defer.returnValue(og)
             return
 
         media_info = yield self._download_url(url, user)
@@ -213,14 +220,17 @@ class PreviewUrlResource(Resource):
         elif _is_html(media_info['media_type']):
             # TODO: somehow stop a big HTML tree from exploding synapse's RAM
 
-            file = open(media_info['filename'])
-            body = file.read()
-            file.close()
+            with open(media_info['filename'], 'rb') as file:
+                body = file.read()
 
             # clobber the encoding from the content-type, or default to utf-8
             # XXX: this overrides any <meta/> or XML charset headers in the body
             # which may pose problems, but so far seems to work okay.
-            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
+            match = re.match(
+                r'.*; *charset="?(.*?)"?(;|$)',
+                media_info['media_type'],
+                re.I
+            )
             encoding = match.group(1) if match else "utf-8"
 
             og = decode_and_calc_og(body, media_info['uri'], encoding)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
new file mode 100644
index 0000000000..29579cf091
--- /dev/null
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from mock import Mock
+
+from twisted.internet.defer import Deferred
+
+from synapse.config.repository import MediaStorageProviderConfig
+from synapse.util.module_loader import load_module
+
+from tests import unittest
+
+
+class URLPreviewTests(unittest.HomeserverTestCase):
+
+    hijack_auth = True
+    user_id = "@test:user"
+
+    def make_homeserver(self, reactor, clock):
+
+        self.storage_path = self.mktemp()
+        os.mkdir(self.storage_path)
+
+        config = self.default_config()
+        config.url_preview_enabled = True
+        config.max_spider_size = 9999999
+        config.url_preview_url_blacklist = []
+        config.media_store_path = self.storage_path
+
+        provider_config = {
+            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        loaded = list(load_module(provider_config)) + [
+            MediaStorageProviderConfig(False, False, False)
+        ]
+
+        config.media_storage_providers = [loaded]
+
+        hs = self.setup_test_homeserver(config=config)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+
+        self.fetches = []
+
+        def get_file(url, output_stream, max_size):
+            """
+            Returns tuple[int,dict,str,int] of file length, response headers,
+            absolute URI, and response code.
+            """
+
+            def write_to(r):
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            d = Deferred()
+            d.addCallback(write_to)
+            self.fetches.append((d, url))
+            return d
+
+        client = Mock()
+        client.get_file = get_file
+
+        self.media_repo = hs.get_media_repository_resource()
+        preview_url = self.media_repo.children[b'preview_url']
+        preview_url.client = client
+        self.preview_url = preview_url
+
+    def test_cache_returns_correct_type(self):
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # We've made one fetch
+        self.assertEqual(len(self.fetches), 1)
+
+        end_content = (
+            b'<html><head>'
+            b'<meta property="og:title" content="~matrix~" />'
+            b'<meta property="og:description" content="hi" />'
+            b'</head></html>'
+        )
+
+        self.fetches[0][0].callback(
+            (
+                end_content,
+                (
+                    len(end_content),
+                    {
+                        b"Content-Length": [b"%d" % (len(end_content))],
+                        b"Content-Type": [b'text/html; charset="utf8"'],
+                    },
+                    "https://example.com",
+                    200,
+                ),
+            )
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
+
+        # Check the cache returns the correct response
+        request, channel = self.make_request(
+            "GET", "url_preview?url=matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # Only one fetch, still, since we'll lean on the cache
+        self.assertEqual(len(self.fetches), 1)
+
+        # Check the cache response has the same content
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
+
+        # Clear the in-memory cache
+        self.assertIn("matrix.org", self.preview_url._cache)
+        self.preview_url._cache.pop("matrix.org")
+        self.assertNotIn("matrix.org", self.preview_url._cache)
+
+        # Check the database cache returns the correct response
+        request, channel = self.make_request(
+            "GET", "url_preview?url=matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        # Only one fetch, still, since we'll lean on the cache
+        self.assertEqual(len(self.fetches), 1)
+
+        # Check the cache response has the same content
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
+        )
diff --git a/tests/server.py b/tests/server.py
index 984cfe26d4..7919a1f124 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -57,6 +57,8 @@ class FakeChannel(object):
         self.result["headers"] = headers
 
     def write(self, content):
+        assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+
         if "body" not in self.result:
             self.result["body"] = b""