diff --git a/synapse/util/async.py b/synapse/util/async_helpers.py
index a7094e2fb4..9b3f2f4b96 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async_helpers.py
@@ -188,62 +188,30 @@ class Linearizer(object):
# things blocked from executing.
self.key_to_defer = {}
- @defer.inlineCallbacks
def queue(self, key):
+ # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
+ # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
+ # propagated inside inlineCallbacks until Twisted 18.7)
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
- # When on of the things currently executing finishes it will callback
+ # When one of the things currently executing finishes it will callback
# this item so that it can continue executing.
if entry[0] >= self.max_count:
- new_defer = defer.Deferred()
- entry[1][new_defer] = 1
-
- logger.info(
- "Waiting to acquire linearizer lock %r for key %r", self.name, key,
- )
- try:
- yield make_deferred_yieldable(new_defer)
- except Exception as e:
- if isinstance(e, CancelledError):
- logger.info(
- "Cancelling wait for linearizer lock %r for key %r",
- self.name, key,
- )
- else:
- logger.warn(
- "Unexpected exception waiting for linearizer lock %r for key %r",
- self.name, key,
- )
-
- # we just have to take ourselves back out of the queue.
- del entry[1][new_defer]
- raise
-
- logger.info("Acquired linearizer lock %r for key %r", self.name, key)
- entry[0] += 1
-
- # if the code holding the lock completes synchronously, then it
- # will recursively run the next claimant on the list. That can
- # relatively rapidly lead to stack exhaustion. This is essentially
- # the same problem as http://twistedmatrix.com/trac/ticket/9304.
- #
- # In order to break the cycle, we add a cheeky sleep(0) here to
- # ensure that we fall back to the reactor between each iteration.
- #
- # (This needs to happen while we hold the lock, and the context manager's exit
- # code must be synchronous, so this is the only sensible place.)
- yield self._clock.sleep(0)
-
+ res = self._await_lock(key)
else:
logger.info(
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
)
entry[0] += 1
+ res = defer.succeed(None)
+
+ # once we successfully get the lock, we need to return a context manager which
+ # will release the lock.
@contextmanager
- def _ctx_manager():
+ def _ctx_manager(_):
try:
yield
finally:
@@ -264,7 +232,64 @@ class Linearizer(object):
# map.
del self.key_to_defer[key]
- defer.returnValue(_ctx_manager())
+ res.addCallback(_ctx_manager)
+ return res
+
+ def _await_lock(self, key):
+     """Helper for queue: adds a deferred to the queue
+
+     Assumes that we've already checked that we've reached the limit of the number
+     of lock-holders we allow. Creates a new deferred which is added to the list, and
+     adds some management around cancellations.
+
+     Args:
+         key: the key whose lock we are waiting for. It must already have an
+             entry in self.key_to_defer (a KeyError is raised otherwise).
+
+     Returns:
+         the deferred, which will callback once we have secured the lock. If the
+         wait fails or is cancelled, we are removed from the queue and the
+         failure is passed on to the caller.
+     """
+     entry = self.key_to_defer[key]
+
+     logger.info(
+         "Waiting to acquire linearizer lock %r for key %r", self.name, key,
+     )
+
+     new_defer = make_deferred_yieldable(defer.Deferred())
+     entry[1][new_defer] = 1
+
+     def cb(_r):
+         logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+         entry[0] += 1
+
+         # if the code holding the lock completes synchronously, then it
+         # will recursively run the next claimant on the list. That can
+         # relatively rapidly lead to stack exhaustion. This is essentially
+         # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+         #
+         # In order to break the cycle, we add a cheeky sleep(0) here to
+         # ensure that we fall back to the reactor between each iteration.
+         #
+         # (This needs to happen while we hold the lock, and the context manager's exit
+         # code must be synchronous, so this is the only sensible place.)
+         return self._clock.sleep(0)
+
+     def eb(e):
+         logger.info("defer %r got err %r", new_defer, e)
+
+         # errbacks are called with a twisted Failure wrapping the exception, so
+         # an isinstance() test against the exception class can never match; we
+         # have to ask the Failure what it wraps.
+         if e.check(CancelledError):
+             logger.info(
+                 "Cancelling wait for linearizer lock %r for key %r",
+                 self.name, key,
+             )
+
+         else:
+             logger.warning(
+                 "Unexpected exception waiting for linearizer lock %r for key %r",
+                 self.name, key,
+             )
+
+         # we just have to take ourselves back out of the queue.
+         del entry[1][new_defer]
+
+         # return the failure, so that it propagates to whoever is waiting on us.
+         return e
+
+     new_defer.addCallbacks(cb, eb)
+     return new_defer
class ReadWriteLock(object):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..187510576a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,7 +25,7 @@ from six import itervalues, string_types
from twisted.internet import defer
from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
- # If we're passed a cache_context then we'll want to call its invalidate()
- # whenever we are invalidated
+ # If we're passed a cache_context then we'll want to call its
+ # invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- # cached is a dict arg -> deferred, where deferred results in a
- # 2-tuple (`arg`, `result`)
results = {}
- cached_defers = {}
- missing = []
+
+ def update_results_dict(res, arg):
+ results[arg] = res
+
+ # list of deferreds to wait for
+ cached_defers = []
+
+ missing = set()
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
- def cache_get(arg):
- return cache.get(arg, callback=invalidate_callback)
+ def arg_to_cache_key(arg):
+ return arg
else:
- key = list(keyargs)
+ keylist = list(keyargs)
- def cache_get(arg):
- key[self.list_pos] = arg
- return cache.get(tuple(key), callback=invalidate_callback)
+ def arg_to_cache_key(arg):
+ keylist[self.list_pos] = arg
+ return tuple(keylist)
for arg in list_args:
try:
- res = cache_get(arg)
-
+ res = cache.get(arg_to_cache_key(arg),
+ callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached_defers[arg] = res
+ res.addCallback(update_results_dict, arg)
+ cached_defers.append(res)
else:
results[arg] = res.get_result()
except KeyError:
- missing.append(arg)
+ missing.add(arg)
if missing:
+ # we need an observable deferred for each entry in the list,
+ # which we put in the cache. Each deferred resolves with the
+ # relevant result for that key.
+ deferreds_map = {}
+ for arg in missing:
+ deferred = defer.Deferred()
+ deferreds_map[arg] = deferred
+ key = arg_to_cache_key(arg)
+ observable = ObservableDeferred(deferred)
+ cache.set(key, observable, callback=invalidate_callback)
+
+ def complete_all(res):
+ # the wrapped function has completed. It returns a
+ # dict. We can now resolve the observable deferreds in
+ # the cache and update our own result map.
+ for e in missing:
+ val = res.get(e, None)
+ deferreds_map[e].callback(val)
+ results[e] = val
+
+ def errback(f):
+ # the wrapped function has failed. Invalidate any cache
+ # entries we're supposed to be populating, and fail
+ # their deferreds.
+ for e in missing:
+ key = arg_to_cache_key(e)
+ cache.invalidate(key)
+ deferreds_map[e].errback(f)
+
+ # return the failure, to propagate to our caller.
+ return f
+
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = list(missing)
- ret_d = defer.maybeDeferred(
+ cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call),
**args_to_call
- )
-
- ret_d = ObservableDeferred(ret_d)
-
- # We need to create deferreds for each arg in the list so that
- # we can insert the new deferred into the cache.
- for arg in missing:
- observer = ret_d.observe()
- observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
- observer = ObservableDeferred(observer)
-
- if num_args == 1:
- cache.set(
- arg, observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, arg)
- else:
- key = list(keyargs)
- key[self.list_pos] = arg
- cache.set(
- tuple(key), observer,
- callback=invalidate_callback
- )
-
- def invalidate(f, key):
- cache.invalidate(key)
- return f
- observer.addErrback(invalidate, tuple(key))
-
- res = observer.observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
-
- cached_defers[arg] = res
+ ).addCallbacks(complete_all, errback))
if cached_defers:
- def update_results_dict(res):
- results.update(res)
- return results
-
- return logcontext.make_deferred_yieldable(defer.gatherResults(
- list(cached_defers.values()),
+ d = defer.gatherResults(
+ cached_defers,
consumeErrors=True,
- ).addCallback(update_results_dict).addErrback(
+ ).addCallbacks(
+ lambda _: results,
unwrapFirstError
- ))
+ )
+ return logcontext.make_deferred_yieldable(d)
else:
return results
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cache.
Args:
- cache (Cache): The underlying cache to use.
+ cached_method_name (str): The name of the single-item lookup method.
+ This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a8491b42d5..afb03b2e1b 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -16,7 +16,7 @@ import logging
from twisted.internet import defer
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index d03678b8c8..8318db8d2c 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
class SnapshotCache(object):
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 581c6052ac..014edea971 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import string_types
+from six import binary_type, text_type
from canonicaljson import json
from frozendict import frozendict
@@ -26,7 +26,7 @@ def freeze(o):
if isinstance(o, frozendict):
return o
- if isinstance(o, string_types):
+ if isinstance(o, (binary_type, text_type)):
return o
try:
@@ -41,7 +41,7 @@ def unfreeze(o):
if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()})
- if isinstance(o, string_types):
+ if isinstance(o, (binary_type, text_type)):
return o
try:
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 8dcae50b39..a0c2d37610 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -385,7 +385,13 @@ class LoggingContextFilter(logging.Filter):
context = LoggingContext.current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
- context.copy_to(record)
+
+ # context should never be None, but if it somehow ends up being, then
+ # we end up in a death spiral of infinite loops, so let's check, for
+ # robustness' sake.
+ if context is not None:
+ context.copy_to(record)
+
return True
@@ -396,7 +402,9 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
- def __init__(self, new_context=LoggingContext.sentinel):
+ def __init__(self, new_context=None):
+ if new_context is None:
+ new_context = LoggingContext.sentinel
self.new_context = new_context
def __enter__(self):
@@ -526,7 +534,7 @@ _to_ignore = [
"synapse.util.logcontext",
"synapse.http.server",
"synapse.storage._base",
- "synapse.util.async",
+ "synapse.util.async_helpers",
]
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 62a00189cc..ef31458226 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -20,6 +20,8 @@ import time
from functools import wraps
from inspect import getcallargs
+from six import PY3
+
_TIME_FUNC_ID = 0
@@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG):
- lineno = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
+ if PY3:
+ lineno = f.__code__.co_firstlineno
+ pathname = f.__code__.co_filename
+ else:
+ lineno = f.func_code.co_firstlineno
+ pathname = f.func_code.co_filename
record = logging.LogRecord(
name=name,
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 43d9db67ec..6f318c6a29 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,6 +16,7 @@
import random
import string
+from six import PY3
from six.moves import range
_string_with_symbols = (
@@ -34,6 +35,17 @@ def random_string_with_symbols(length):
def is_ascii(s):
+
+ if PY3:
+ if isinstance(s, bytes):
+ try:
+ s.decode('ascii').encode('ascii')
+ except UnicodeDecodeError:
+ return False
+ except UnicodeEncodeError:
+ return False
+ return True
+
try:
s.encode("ascii")
except UnicodeEncodeError:
@@ -49,6 +61,9 @@ def to_ascii(s):
If given None then will return None.
"""
+ if PY3:
+ return s
+
if s is None:
return None
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 1fbcd41115..3baba3225a 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -30,7 +30,7 @@ def get_version_string(module):
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
@@ -40,7 +40,7 @@ def get_version_string(module):
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
@@ -50,7 +50,7 @@ def get_version_string(module):
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
except subprocess.CalledProcessError:
git_commit = ""
@@ -60,7 +60,7 @@ def get_version_string(module):
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
- ).strip().endswith(dirty_string)
+ ).strip().decode('ascii').endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
@@ -77,8 +77,8 @@ def get_version_string(module):
"%s (%s)" % (
module.__version__, git_version,
)
- ).encode("ascii")
+ )
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__.encode("ascii")
+ return module.__version__
|