summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-08-12 17:07:22 +0100
committerMark Haines <mark.haines@matrix.org>2015-08-12 17:21:14 +0100
commit998a72d4d9ec6e73000888dcdf51437ec427fbee (patch)
treef8fa9d5deb820b49eb3a216e194ee83e14fb9eda /synapse/util
parentBump the version of twisted needed for setup_requires to 15.2.1 (diff)
parentMerge pull request #220 from matrix-org/markjh/generate_keys (diff)
downloadsynapse-998a72d4d9ec6e73000888dcdf51437ec427fbee.tar.xz
Merge branch 'develop' into markjh/twisted-15
Conflicts:
	synapse/http/matrixfederationclient.py
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py8
-rw-r--r--synapse/util/async.py25
-rw-r--r--synapse/util/logcontext.py52
-rw-r--r--synapse/util/stringutils.py9
4 files changed, 69 insertions, 25 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 260714ccc2..07ff25cef3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -91,8 +91,12 @@ class Clock(object):
         with PreserveLoggingContext():
             return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
 
-    def cancel_call_later(self, timer):
-        timer.cancel()
+    def cancel_call_later(self, timer, ignore_errs=False):
+        try:
+            timer.cancel()
+        except:
+            if not ignore_errs:
+                raise
 
     def time_bound_deferred(self, given_deferred, time_out):
         if given_deferred.called:
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1c2044e5b4..7bf2d38bb8 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -38,6 +38,9 @@ class ObservableDeferred(object):
     deferred.
 
     If consumeErrors is true errors will be captured from the origin deferred.
+
+    Cancelling or otherwise resolving an observer will not affect the original
+    ObservableDeferred.
     """
 
     __slots__ = ["_deferred", "_observers", "_result"]
@@ -45,10 +48,10 @@ class ObservableDeferred(object):
     def __init__(self, deferred, consumeErrors=False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
-        object.__setattr__(self, "_observers", [])
+        object.__setattr__(self, "_observers", set())
 
         def callback(r):
-            self._result = (True, r)
+            object.__setattr__(self, "_result", (True, r))
             while self._observers:
                 try:
                     self._observers.pop().callback(r)
@@ -57,7 +60,7 @@ class ObservableDeferred(object):
             return r
 
         def errback(f):
-            self._result = (False, f)
+            object.__setattr__(self, "_result", (False, f))
             while self._observers:
                 try:
                     self._observers.pop().errback(f)
@@ -74,14 +77,28 @@ class ObservableDeferred(object):
     def observe(self):
         if not self._result:
             d = defer.Deferred()
-            self._observers.append(d)
+
+            def remove(r):
+                self._observers.discard(d)
+                return r
+            d.addBoth(remove)
+
+            self._observers.add(d)
             return d
         else:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
+    def observers(self):
+        return self._observers
+
     def __getattr__(self, name):
         return getattr(self._deferred, name)
 
     def __setattr__(self, name, value):
         setattr(self._deferred, name, value)
+
+    def __repr__(self):
+        return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
+            id(self), self._result, self._deferred,
+        )
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a92d518b43..7e6062c1b8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
                 )
 
 
+class _PreservingContextDeferred(defer.Deferred):
+    """A deferred that ensures that all callbacks and errbacks are called with
+    the given logging context.
+    """
+    def __init__(self, context):
+        self._log_context = context
+        defer.Deferred.__init__(self)
+
+    def addCallbacks(self, callback, errback=None,
+                     callbackArgs=None, callbackKeywords=None,
+                     errbackArgs=None, errbackKeywords=None):
+        callback = self._wrap_callback(callback)
+        errback = self._wrap_callback(errback)
+        return defer.Deferred.addCallbacks(
+            self, callback,
+            errback=errback,
+            callbackArgs=callbackArgs,
+            callbackKeywords=callbackKeywords,
+            errbackArgs=errbackArgs,
+            errbackKeywords=errbackKeywords,
+        )
+
+    def _wrap_callback(self, f):
+        def g(res, *args, **kwargs):
+            with PreserveLoggingContext():
+                LoggingContext.thread_local.current_context = self._log_context
+                res = f(res, *args, **kwargs)
+            return res
+        return g
+
+
 def preserve_context_over_fn(fn, *args, **kwargs):
     """Takes a function and invokes it with the given arguments, but removes
     and restores the current logging context while doing so.
@@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
     """
-    d = defer.Deferred()
-
     current_context = LoggingContext.current_context()
-
-    def cb(res):
-        with PreserveLoggingContext():
-            LoggingContext.thread_local.current_context = current_context
-            res = d.callback(res)
-        return res
-
-    def eb(failure):
-        with PreserveLoggingContext():
-            LoggingContext.thread_local.current_context = current_context
-            res = d.errback(failure)
-        return res
-
-    if deferred.called:
-        return deferred
-
-    deferred.addCallbacks(cb, eb)
+    d = _PreservingContextDeferred(current_context)
+    deferred.chainDeferred(d)
     return d
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 52e66beaee..7a1e96af37 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -33,3 +33,12 @@ def random_string_with_symbols(length):
     return ''.join(
         random.choice(_string_with_symbols) for _ in xrange(length)
     )
+
+
+def is_ascii(s):
+    try:
+        s.encode("ascii")
+    except UnicodeDecodeError:
+        return False
+    else:
+        return True