diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 8709394b97..a859872ce2 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred
from OpenSSL import crypto
@@ -111,6 +111,10 @@ class Keyring(object):
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
+ download = ObservableDeferred(
+ download,
+ consumeErrors=True
+ )
self.key_downloads[server_name] = download
@download.addBoth
@@ -118,7 +122,7 @@ class Keyring(object):
del self.key_downloads[server_name]
return ret
- r = yield create_observer(download)
+ r = yield download.observe()
defer.returnValue(r)
@defer.inlineCallbacks
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index 08c8d75af4..4af5f73878 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred
import os
@@ -83,13 +83,17 @@ class BaseMediaResource(Resource):
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
+ download = ObservableDeferred(
+ download,
+ consumeErrors=True
+ )
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
- return create_observer(download)
+ return download.observe()
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..34acb14a6f 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -34,20 +34,56 @@ def run_on_reactor():
return sleep(0)
-def create_observer(deferred):
- """Creates a deferred that observes the result or failure of the given
- deferred *without* affecting the given deferred.
+class ObservableDeferred(object):
+ """Wraps a deferred object so that we can add observer deferreds. These
+ observer deferreds do not affect the callback chain of the original
+ deferred.
+
+ If consumeErrors is true errors will be captured from the origin deferred.
"""
- d = defer.Deferred()
- def callback(r):
- d.callback(r)
- return r
+ __slots__ = ["_deferred", "_observers", "_result"]
+
+ def __init__(self, deferred, consumeErrors=False):
+ object.__setattr__(self, "_deferred", deferred)
+ object.__setattr__(self, "_result", None)
+ object.__setattr__(self, "_observers", [])
+
+ def callback(r):
+ self._result = (True, r)
+ while self._observers:
+ try:
+ self._observers.pop().callback(r)
+ except:
+ pass
+ return r
+
+ def errback(f):
+ self._result = (False, f)
+ while self._observers:
+ try:
+ self._observers.pop().errback(f)
+ except:
+ pass
+
+ if consumeErrors:
+ return None
+ else:
+ return f
+
+ deferred.addCallbacks(callback, errback)
- def errback(f):
- d.errback(f)
- return f
+ def observe(self):
+ if not self._result:
+ d = defer.Deferred()
+ self._observers.append(d)
+ return d
+ else:
+ success, res = self._result
+ return defer.succeed(res) if success else defer.fail(res)
- deferred.addCallbacks(callback, errback)
+ def __getattr__(self, name):
+ return getattr(self._deferred, name)
- return d
+ def __setattr__(self, name, value):
+ setattr(self._deferred, name, value)
|