summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/crypto/keyring.py8
-rw-r--r--synapse/rest/media/v1/base_resource.py8
-rw-r--r--synapse/util/async.py60
3 files changed, 60 insertions, 16 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 8709394b97..a859872ce2 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
 
 from synapse.util.retryutils import get_retry_limiter
 
-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred
 
 from OpenSSL import crypto
 
@@ -111,6 +111,10 @@ class Keyring(object):
 
         if download is None:
             download = self._get_server_verify_key_impl(server_name, key_ids)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
             self.key_downloads[server_name] = download
 
             @download.addBoth
@@ -118,7 +122,7 @@ class Keyring(object):
                 del self.key_downloads[server_name]
                 return ret
 
-        r = yield create_observer(download)
+        r = yield download.observe()
         defer.returnValue(r)
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index 08c8d75af4..4af5f73878 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
 from twisted.web.resource import Resource
 from twisted.protocols.basic import FileSender
 
-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred
 
 import os
 
@@ -83,13 +83,17 @@ class BaseMediaResource(Resource):
         download = self.downloads.get(key)
         if download is None:
             download = self._get_remote_media_impl(server_name, media_id)
+            download = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
             self.downloads[key] = download
 
             @download.addBoth
             def callback(media_info):
                 del self.downloads[key]
                 return media_info
-        return create_observer(download)
+        return download.observe()
 
     @defer.inlineCallbacks
     def _get_remote_media_impl(self, server_name, media_id):
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..34acb14a6f 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -34,20 +34,56 @@ def run_on_reactor():
     return sleep(0)
 
 
-def create_observer(deferred):
-    """Creates a deferred that observes the result or failure of the given
-     deferred *without* affecting the given deferred.
+class ObservableDeferred(object):
+    """Wraps a deferred object so that we can add observer deferreds. These
+    observer deferreds do not affect the callback chain of the original
+    deferred.
+
+    If consumeErrors is true errors will be captured from the origin deferred.
     """
-    d = defer.Deferred()
 
-    def callback(r):
-        d.callback(r)
-        return r
+    __slots__ = ["_deferred", "_observers", "_result"]
+
+    def __init__(self, deferred, consumeErrors=False):
+        object.__setattr__(self, "_deferred", deferred)
+        object.__setattr__(self, "_result", None)
+        object.__setattr__(self, "_observers", [])
+
+        def callback(r):
+            self._result = (True, r)
+            while self._observers:
+                try:
+                    self._observers.pop().callback(r)
+                except:
+                    pass
+            return r
+
+        def errback(f):
+            self._result = (False, f)
+            while self._observers:
+                try:
+                    self._observers.pop().errback(f)
+                except:
+                    pass
+
+            if consumeErrors:
+                return None
+            else:
+                return f
+
+        deferred.addCallbacks(callback, errback)
 
-    def errback(f):
-        d.errback(f)
-        return f
+    def observe(self):
+        if not self._result:
+            d = defer.Deferred()
+            self._observers.append(d)
+            return d
+        else:
+            success, res = self._result
+            return defer.succeed(res) if success else defer.fail(res)
 
-    deferred.addCallbacks(callback, errback)
+    def __getattr__(self, name):
+        return getattr(self._deferred, name)
 
-    return d
+    def __setattr__(self, name, value):
+        setattr(self._deferred, name, value)