summary refs log tree commit diff
path: root/synapse/rest/media/v1
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/media/v1')
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py86
1 files changed, 45 insertions, 41 deletions
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 7907a9d17a..38e1afd34b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -20,6 +20,7 @@ from twisted.web.resource import Resource
 from synapse.api.errors import (
     SynapseError, Codes,
 )
+from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
 from synapse.util.stringutils import random_string
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.http.client import SpiderHttpClient
@@ -63,16 +64,15 @@ class PreviewUrlResource(Resource):
 
         self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
 
-        # simple memory cache mapping urls to OG metadata
-        self.cache = ExpiringCache(
+        # memory cache mapping urls to an ObservableDeferred returning
+        # JSON-encoded OG metadata
+        self._cache = ExpiringCache(
             cache_name="url_previews",
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=60 * 60 * 1000,
         )
-        self.cache.start()
-
-        self.downloads = {}
+        self._cache.start()
 
         self._cleaner_loop = self.clock.looping_call(
             self._expire_url_cache_data, 10 * 1000
@@ -94,6 +94,7 @@ class PreviewUrlResource(Resource):
         else:
             ts = self.clock.time_msec()
 
+        # XXX: we could move this into _do_preview if we wanted.
         url_tuple = urlparse.urlsplit(url)
         for entry in self.url_preview_url_blacklist:
             match = True
@@ -126,14 +127,40 @@ class PreviewUrlResource(Resource):
                     Codes.UNKNOWN
                 )
 
-        # first check the memory cache - good to handle all the clients on this
-        # HS thundering away to preview the same URL at the same time.
-        og = self.cache.get(url)
-        if og:
-            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
-            return
+        # the in-memory cache:
+        # * ensures that only one request is active at a time
+        # * takes load off the DB for the thundering herds
+        # * also caches any failures (unlike the DB) so we don't keep
+        #    requesting the same endpoint
+
+        observable = self._cache.get(url)
+
+        if not observable:
+            download = preserve_fn(self._do_preview)(
+                url, requester.user, ts,
+            )
+            observable = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
+            self._cache[url] = observable
 
-        # then check the URL cache in the DB (which will also provide us with
+        og = yield make_deferred_yieldable(observable.observe())
+        respond_with_json_bytes(request, 200, og, send_cors=True)
+
+    @defer.inlineCallbacks
+    def _do_preview(self, url, user, ts):
+        """Check the db, and download the URL and build a preview
+
+        Args:
+            url (str):
+            user (str):
+            ts (int):
+
+        Returns:
+            Deferred[str]: json-encoded og data
+        """
+        # check the URL cache in the DB (which will also provide us with
         # historical previews, if we have any)
         cache_result = yield self.store.get_url_cache(url, ts)
         if (
@@ -141,32 +168,10 @@ class PreviewUrlResource(Resource):
             cache_result["expires_ts"] > ts and
             cache_result["response_code"] / 100 == 2
         ):
-            respond_with_json_bytes(
-                request, 200, cache_result["og"].encode('utf-8'),
-                send_cors=True
-            )
+            defer.returnValue(cache_result["og"])
             return
 
-        # Ensure only one download for a given URL is active at a time
-        download = self.downloads.get(url)
-        if download is None:
-            download = self._download_url(url, requester.user)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
-            )
-            self.downloads[url] = download
-
-            @download.addBoth
-            def callback(media_info):
-                del self.downloads[url]
-                return media_info
-        media_info = yield download.observe()
-
-        # FIXME: we should probably update our cache now anyway, so that
-        # even if the OG calculation raises, we don't keep hammering on the
-        # remote server.  For now, leave it uncached to aid debugging OG
-        # calculation problems
+        media_info = yield self._download_url(url, user)
 
         logger.debug("got media_info of '%s'" % media_info)
 
@@ -212,7 +217,7 @@ class PreviewUrlResource(Resource):
             # just rely on the caching on the master request to speed things up.
             if 'og:image' in og and og['og:image']:
                 image_info = yield self._download_url(
-                    _rebase_url(og['og:image'], media_info['uri']), requester.user
+                    _rebase_url(og['og:image'], media_info['uri']), user
                 )
 
                 if _is_media(image_info['media_type']):
@@ -239,8 +244,7 @@ class PreviewUrlResource(Resource):
 
         logger.debug("Calculated OG for %s as %s" % (url, og))
 
-        # store OG in ephemeral in-memory cache
-        self.cache[url] = og
+        jsonog = json.dumps(og)
 
         # store OG in history-aware DB cache
         yield self.store.store_url_cache(
@@ -248,12 +252,12 @@ class PreviewUrlResource(Resource):
             media_info["response_code"],
             media_info["etag"],
             media_info["expires"] + media_info["created_ts"],
-            json.dumps(og),
+            jsonog,
             media_info["filesystem_id"],
             media_info["created_ts"],
         )
 
-        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+        defer.returnValue(jsonog)
 
     @defer.inlineCallbacks
     def _download_url(self, url, user):