diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 3b9da5b34a..c05b9450be 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.errors import SynapseError
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -33,7 +34,7 @@ class Clock(object):
"""A small utility that obtains current time-of-day so that time may be
mocked during unit-tests.
- TODO(paul): Also move the sleep() functionallity into it
+ TODO(paul): Also move the sleep() functionality into it
"""
def time(self):
@@ -45,13 +46,18 @@ class Clock(object):
return int(self.time() * 1000)
def looping_call(self, f, msec):
+ """Call a function repeatedly.
+
+ Waits `msec` initially before calling `f` for the first time.
+
+ Args:
+ f(function): The function to call repeatedly.
+ msec(float): How long to wait between calls in milliseconds.
+ """
l = task.LoopingCall(f)
l.start(msec / 1000.0, now=False)
return l
- def stop_looping_call(self, loop):
- loop.stop()
-
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
@@ -83,7 +89,7 @@ class Clock(object):
def timed_out_fn():
try:
- ret_deferred.errback(RuntimeError("Timed out"))
+ ret_deferred.errback(SynapseError(504, "Timed out"))
except:
pass
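
A usage sketch of the `Clock` API above; `tick()` is illustrative. Note that `stop_looping_call` is removed, so callers stop the returned `LoopingCall` directly:

    from twisted.internet import reactor
    from synapse.util import Clock

    def tick():
        print("tick")

    clock = Clock()
    loop = clock.looping_call(tick, 500)  # first call after 500ms, then every 500ms
    reactor.callLater(2, loop.stop)       # replaces the removed stop_looping_call
    reactor.callLater(2.5, reactor.stop)
    reactor.run()
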
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 640fae3890..347fb1e380 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,7 +16,12 @@
from twisted.internet import defer, reactor
-from .logcontext import PreserveLoggingContext
+from .logcontext import (
+ PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+)
+from synapse.util import unwrapFirstError
+
+from contextlib import contextmanager
@defer.inlineCallbacks
@@ -97,6 +102,15 @@ class ObservableDeferred(object):
def observers(self):
return self._observers
+ def has_called(self):
+ return self._result is not None
+
+ def has_succeeded(self):
+ return self._result is not None and self._result[0] is True
+
+ def get_result(self):
+ return self._result[1]
+
def __getattr__(self, name):
return getattr(self._deferred, name)
@@ -107,3 +121,159 @@ class ObservableDeferred(object):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)
+
+
+def concurrently_execute(func, args, limit):
+ """Executes the function with each argument concurrently while limiting
+ the number of concurrent executions.
+
+ Args:
+ func (func): Function to execute, should return a deferred.
+ args (list): List of arguments to pass to func, each invocation of func
+ gets a single argument.
+ limit (int): Maximum number of concurrent executions.
+
+ Returns:
+ deferred: Resolved when all function invocations have finished.
+ """
+ it = iter(args)
+
+ @defer.inlineCallbacks
+ def _concurrently_execute_inner():
+ try:
+ while True:
+ yield func(it.next())
+ except StopIteration:
+ pass
+
+ return preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(_concurrently_execute_inner)()
+ for _ in xrange(limit)
+ ], consumeErrors=True)).addErrback(unwrapFirstError)
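
A usage sketch of `concurrently_execute`; `fetch()` is a hypothetical deferred-returning worker:

    from twisted.internet import defer
    from synapse.util.async import concurrently_execute

    def fetch(arg):
        # Pretend to do asynchronous work for one argument.
        return defer.succeed(arg * 2)

    # At most 5 invocations of fetch are in flight at once; the returned
    # deferred fires once every argument has been consumed.
    d = concurrently_execute(fetch, range(100), limit=5)
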
+
+
+class Linearizer(object):
+ """Linearizes access to resources based on a key. Useful to ensure only one
+ thing is happening at a time on a given resource.
+
+ Example:
+
+ with (yield linearizer.queue("test_key")):
+ # do some work.
+
+ """
+ def __init__(self):
+ self.key_to_defer = {}
+
+ @defer.inlineCallbacks
+ def queue(self, key):
+ # If there is already a deferred in the queue, we pull it out so that
+ # we can wait on it later.
+ # Then we replace it with a deferred that we resolve *after* the
+ # context manager has exited.
+ # We only return the context manager after the previous deferred has
+ # resolved.
+ # This all has the net effect of creating a chain of deferreds that
+ # wait for the previous deferred before starting their work.
+ current_defer = self.key_to_defer.get(key)
+
+ new_defer = defer.Deferred()
+ self.key_to_defer[key] = new_defer
+
+ if current_defer:
+ with PreserveLoggingContext():
+ yield current_defer
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ current_d = self.key_to_defer.get(key)
+ if current_d is new_defer:
+ self.key_to_defer.pop(key, None)
+
+ defer.returnValue(_ctx_manager())
+
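
A usage sketch of `Linearizer`; `do_work()` is hypothetical:

    from twisted.internet import defer
    from synapse.util.async import Linearizer

    linearizer = Linearizer()

    @defer.inlineCallbacks
    def handle_event(room_id):
        # Work for the same room_id runs strictly one at a time, via the
        # deferred chain described in queue() above.
        with (yield linearizer.queue(room_id)):
            yield do_work(room_id)  # hypothetical
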
+
+class ReadWriteLock(object):
+ """A deferred style read write lock.
+
+ Example:
+
+ with (yield read_write_lock.read("test_key")):
+ # do some work
+ """
+
+ # IMPLEMENTATION NOTES
+ #
+ # We track the most recent queued reader and writer deferreds (which get
+ # resolved when they release the lock).
+ #
+ # Read: We know it's safe to acquire a read lock when the latest writer has
+ # been resolved. The new reader is appended to the list of latest readers.
+ #
+ # Write: We know it's safe to acquire the write lock when both the latest
+ # writer and readers have been resolved. The new writer replaces the latest
+ # writer.
+
+ def __init__(self):
+ # Latest readers queued
+ self.key_to_current_readers = {}
+
+ # Latest writer queued
+ self.key_to_current_writer = {}
+
+ @defer.inlineCallbacks
+ def read(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.setdefault(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ curr_readers.add(new_defer)
+
+ # We wait for the latest writer to finish writing. We can safely ignore
+ # any existing readers... as they're readers.
+ yield curr_writer
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ self.key_to_current_readers.get(key, set()).discard(new_defer)
+
+ defer.returnValue(_ctx_manager())
+
+ @defer.inlineCallbacks
+ def write(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.get(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ # We wait on all latest readers and writer.
+ to_wait_on = list(curr_readers)
+ if curr_writer:
+ to_wait_on.append(curr_writer)
+
+ # We can clear the list of current readers since the new writer waits
+ # for them to finish.
+ curr_readers.clear()
+ self.key_to_current_writer[key] = new_defer
+
+ yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ if self.key_to_current_writer[key] == new_defer:
+ self.key_to_current_writer.pop(key)
+
+ defer.returnValue(_ctx_manager())
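
A usage sketch of both sides of `ReadWriteLock`; `current_state` is an illustrative shared dict:

    from twisted.internet import defer
    from synapse.util.async import ReadWriteLock

    lock = ReadWriteLock()
    current_state = {}

    @defer.inlineCallbacks
    def read_state(key):
        # Readers may overlap with each other but wait for earlier writers.
        with (yield lock.read(key)):
            defer.returnValue(current_state.get(key))

    @defer.inlineCallbacks
    def write_state(key, value):
        # A writer waits for the previous writer and all earlier readers.
        with (yield lock.write(key)):
            current_state[key] = value
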
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 1a14904194..ebd715c5dc 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -14,14 +14,81 @@
# limitations under the License.
import synapse.metrics
+from lrucache import LruCache
+import os
+
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
-cache_counter = metrics.register_cache(
- "cache",
- lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
- labels=["name"],
-)
+# cache_counter = metrics.register_cache(
+# "cache",
+# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+# labels=["name"],
+# )
+
+
+def register_cache(name, cache):
+ caches_by_name[name] = cache
+ return metrics.register_cache(
+ "cache",
+ lambda: len(cache),
+ name,
+ )
+
+
+_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
+caches_by_name["string_cache"] = _string_cache
+
+
+KNOWN_KEYS = {
+ key: key for key in
+ (
+ "auth_events",
+ "content",
+ "depth",
+ "event_id",
+ "hashes",
+ "origin",
+ "origin_server_ts",
+ "prev_events",
+ "room_id",
+ "sender",
+ "signatures",
+ "state_key",
+ "type",
+ "unsigned",
+ "user_id",
+ )
+}
+
+
+def intern_string(string):
+ """Takes a (potentially) unicode string and interns it using a custom cache
+ """
+ return _string_cache.setdefault(string, string)
+
+
+def intern_dict(dictionary):
+ """Takes a dictionary and interns well-known keys and their values
+ """
+ return {
+ KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
+ for key, value in dictionary.items()
+ }
+
+
+def _intern_known_values(key, value):
+ intern_str_keys = ("event_id", "room_id")
+ intern_unicode_keys = ("sender", "user_id", "type", "state_key")
+
+ if key in intern_str_keys:
+ return intern(value.encode('ascii'))
+
+ if key in intern_unicode_keys:
+ return intern_string(value)
+
+ return value
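
A sketch of the interning helpers; the identifiers are illustrative:

    from synapse.util.caches import intern_string, intern_dict

    # Equal strings parsed at different times collapse to one object,
    # saving memory for identifiers that recur across many events.
    a = intern_string(u"@alice:example.com")
    b = intern_string(u"@alice:example.com")
    assert a is b

    # Known keys (and selected values) are interned; other values pass through.
    event = intern_dict({u"room_id": u"!abc:example.com", u"depth": 5})
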
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 35544b19fd..8dba61d49f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -22,17 +22,17 @@ from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
-from . import caches_by_name, DEBUG_CACHES, cache_counter
+from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
-
-from collections import OrderedDict
+from collections import namedtuple
import os
import functools
import inspect
import threading
+
logger = logging.getLogger(__name__)
@@ -43,23 +43,27 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class Cache(object):
+ __slots__ = (
+ "cache",
+ "max_entries",
+ "name",
+ "keylen",
+ "sequence",
+ "thread",
+ "metrics",
+ )
- def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
- if lru:
- cache_type = TreeCache if tree else dict
- self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type
- )
- self.max_entries = None
- else:
- self.cache = OrderedDict()
- self.max_entries = max_entries
+ def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+ cache_type = TreeCache if tree else dict
+ self.cache = LruCache(
+ max_size=max_entries, keylen=keylen, cache_type=cache_type
+ )
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -71,32 +75,28 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, key, default=_CacheSentinel):
- val = self.cache.get(key, _CacheSentinel)
+ def get(self, key, default=_CacheSentinel, callback=None):
+ val = self.cache.get(key, _CacheSentinel, callback=callback)
if val is not _CacheSentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return val
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
else:
return default
- def update(self, sequence, key, value):
+ def update(self, sequence, key, value, callback=None):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(key, value)
-
- def prefill(self, key, value):
- if self.max_entries is not None:
- while len(self.cache) >= self.max_entries:
- self.cache.popitem(last=False)
+ self.prefill(key, value, callback=callback)
- self.cache[key] = value
+ def prefill(self, key, value, callback=None):
+ self.cache.set(key, value, callback=callback)
def invalidate(self, key):
self.check_thread()
@@ -141,9 +141,21 @@ class CacheDescriptor(object):
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+ functions and get appropriately invalidated when the caches they call are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example::
+
+ @cachedInlineCallbacks(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
+ defer.returnValue(r1 + r2)
+
"""
- def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
- inlineCallbacks=False):
+ def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+ inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -155,34 +167,64 @@ class CacheDescriptor(object):
self.max_entries = max_entries
self.num_args = num_args
- self.lru = lru
self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
+ all_args = inspect.getargspec(orig)
+ self.arg_names = all_args.args[1:num_args + 1]
+
+ if "cache_context" in all_args.args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ try:
+ self.arg_names.remove("cache_context")
+ except ValueError:
+ pass
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
+
+ self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
+ " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
- self.cache = Cache(
+ def __get__(self, obj, objtype=None):
+ cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
- lru=self.lru,
tree=self.tree,
)
- def __get__(self, obj, objtype=None):
-
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
+ # Add temp cache_context so inspect.getcallargs doesn't explode
+ if self.add_cache_context:
+ kwargs["cache_context"] = None
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext(cache, cache_key)
+
try:
- cached_result_d = self.cache.get(cache_key)
+ cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -204,7 +246,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = self.cache.sequence
+ sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
@@ -213,20 +255,21 @@ class CacheDescriptor(object):
)
def onErr(f):
- self.cache.invalidate(cache_key)
+ cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- self.cache.update(sequence, cache_key, ret)
+ cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
- wrapped.invalidate = self.cache.invalidate
- wrapped.invalidate_all = self.cache.invalidate_all
- wrapped.invalidate_many = self.cache.invalidate_many
- wrapped.prefill = self.cache.prefill
+ wrapped.invalidate = cache.invalidate
+ wrapped.invalidate_all = cache.invalidate_all
+ wrapped.invalidate_many = cache.invalidate_many
+ wrapped.prefill = cache.prefill
+ wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
@@ -240,11 +283,12 @@ class CacheListDescriptor(object):
the list of missing keys to the wrapped function.
"""
- def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+ def __init__(self, orig, cached_method_name, list_name, num_args=1,
+ inlineCallbacks=False):
"""
Args:
orig (function)
- cache (Cache)
+ cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
@@ -263,7 +307,7 @@ class CacheListDescriptor(object):
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
- self.cache = cache
+ self.cached_method_name = cached_method_name
self.sentinel = object()
@@ -277,34 +321,45 @@ class CacheListDescriptor(object):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
- % (self.list_name, cache.name,)
+ % (self.list_name, cached_method_name,)
)
def __get__(self, obj, objtype=None):
+ cache = getattr(obj, self.cached_method_name).cache
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
- cached = {}
+ results = {}
+ cached_defers = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
- res = self.cache.get(tuple(key)).observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
+ res = cache.get(tuple(key), callback=invalidate_callback)
+ if not res.has_succeeded():
+ res = res.observe()
+ res.addCallback(lambda r, arg: (arg, r), arg)
+ cached_defers[arg] = res
+ else:
+ results[arg] = res.get_result()
except KeyError:
missing.append(arg)
if missing:
- sequence = self.cache.sequence
+ sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@@ -327,50 +382,67 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
- self.cache.update(sequence, tuple(key), observer)
+ cache.update(
+ sequence, tuple(key), observer,
+ callback=invalidate_callback
+ )
def invalidate(f, key):
- self.cache.invalidate(key)
+ cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
+ cached_defers[arg] = res
+
+ if cached_defers:
+ def update_results_dict(res):
+ results.update(res)
+ return results
- return preserve_context_over_deferred(defer.gatherResults(
- cached.values(),
- consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
+ return preserve_context_over_deferred(defer.gatherResults(
+ cached_defers.values(),
+ consumeErrors=True,
+ ).addCallback(update_results_dict).addErrback(
+ unwrapFirstError
+ ))
+ else:
+ return results
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-def cached(max_entries=1000, num_args=1, lru=True, tree=False):
+class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
+ def invalidate(self):
+ self.cache.invalidate(self.key)
+
+
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
+ cache_context=cache_context,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru,
tree=tree,
inlineCallbacks=True,
+ cache_context=cache_context,
)
-def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -400,7 +472,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
"""
return lambda orig: CacheListDescriptor(
orig,
- cache=cache,
+ cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
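
A sketch of the new by-name wiring between `@cached` and `@cachedList`; `UserStore` and its `db_*` helpers are hypothetical:

    from twisted.internet import defer
    from synapse.util.caches.descriptors import cached, cachedList

    class UserStore(object):
        @cached(num_args=1)
        def get_user_name(self, user_id):
            return self.db_fetch_name(user_id)  # hypothetical

        @cachedList(cached_method_name="get_user_name", list_name="user_ids",
                    num_args=1, inlineCallbacks=True)
        def get_user_names(self, user_ids):
            # Called only with the cache misses; results are written back
            # into get_user_name's per-instance cache.
            rows = yield self.db_fetch_names(user_ids)  # hypothetical dict
            defer.returnValue(rows)
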
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index f92d80542b..b0ca1bb79d 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,7 +15,7 @@
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
-from . import caches_by_name, cache_counter
+from . import register_cache
import threading
import logging
@@ -43,7 +43,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -58,7 +58,7 @@ class DictionaryCache(object):
def get(self, key, dict_keys=None):
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
@@ -69,7 +69,7 @@ class DictionaryCache(object):
if k in entry.value
})
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return DictionaryEntry(False, {})
def invalidate(self, key):
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2b68c1ac93..080388958f 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
import logging
@@ -49,7 +49,7 @@ class ExpiringCache(object):
self._cache = {}
- caches_by_name[cache_name] = self._cache
+ self.metrics = register_cache(cache_name, self._cache)
def start(self):
if not self._expiry_ms:
@@ -78,9 +78,9 @@ class ExpiringCache(object):
def __getitem__(self, key):
try:
entry = self._cache[key]
- cache_counter.inc_hits(self._cache_name)
+ self.metrics.inc_hits()
except KeyError:
- cache_counter.inc_misses(self._cache_name)
+ self.metrics.inc_misses()
raise
if self._reset_expiry_on_get:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f7423f2fab..9c4c679175 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -29,19 +29,32 @@ def enumerate_leaves(node, depth):
yield m
+class _Node(object):
+ __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
+
+ def __init__(self, prev_node, next_node, key, value, callbacks=None):
+ self.prev_node = prev_node
+ self.next_node = next_node
+ self.key = key
+ self.value = value
+ # Use a fresh set per node: a mutable default argument would be
+ # shared by every node created without an explicit callbacks set.
+ self.callbacks = callbacks if callbacks is not None else set()
+
+
class LruCache(object):
"""
Least-recently-used cache.
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
+
+ Can also set callbacks on objects when getting/setting which are fired
+ when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
cache = cache_type()
self.cache = cache # Used for introspection.
- list_root = []
- list_root[:] = [list_root, list_root, None, None]
-
- PREV, NEXT, KEY, VALUE = 0, 1, 2, 3
+ list_root = _Node(None, None, None, None)
+ list_root.next_node = list_root
+ list_root.prev_node = list_root
lock = threading.Lock()
@@ -53,65 +66,83 @@ class LruCache(object):
return inner
- def add_node(key, value):
+ def add_node(key, value, callbacks=None):
prev_node = list_root
- next_node = prev_node[NEXT]
- node = [prev_node, next_node, key, value]
- prev_node[NEXT] = node
- next_node[PREV] = node
+ next_node = prev_node.next_node
+ node = _Node(prev_node, next_node, key, value, callbacks)
+ prev_node.next_node = node
+ next_node.prev_node = node
cache[key] = node
def move_node_to_front(node):
- prev_node = node[PREV]
- next_node = node[NEXT]
- prev_node[NEXT] = next_node
- next_node[PREV] = prev_node
+ prev_node = node.prev_node
+ next_node = node.next_node
+ prev_node.next_node = next_node
+ next_node.prev_node = prev_node
prev_node = list_root
- next_node = prev_node[NEXT]
- node[PREV] = prev_node
- node[NEXT] = next_node
- prev_node[NEXT] = node
- next_node[PREV] = node
+ next_node = prev_node.next_node
+ node.prev_node = prev_node
+ node.next_node = next_node
+ prev_node.next_node = node
+ next_node.prev_node = node
def delete_node(node):
- prev_node = node[PREV]
- next_node = node[NEXT]
- prev_node[NEXT] = next_node
- next_node[PREV] = prev_node
+ prev_node = node.prev_node
+ next_node = node.next_node
+ prev_node.next_node = next_node
+ next_node.prev_node = prev_node
+
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
@synchronized
- def cache_get(key, default=None):
+ def cache_get(key, default=None, callback=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
- return node[VALUE]
+ if callback:
+ node.callbacks.add(callback)
+ return node.value
else:
return default
@synchronized
- def cache_set(key, value):
+ def cache_set(key, value, callback=None):
node = cache.get(key, None)
if node is not None:
+ if value != node.value:
+ for cb in node.callbacks:
+ cb()
+ node.callbacks.clear()
+
+ if callback:
+ node.callbacks.add(callback)
+
move_node_to_front(node)
- node[VALUE] = value
+ node.value = value
else:
- add_node(key, value)
+ if callback:
+ callbacks = set([callback])
+ else:
+ callbacks = set()
+ add_node(key, value, callbacks)
if len(cache) > max_size:
- todelete = list_root[PREV]
+ todelete = list_root.prev_node
delete_node(todelete)
- cache.pop(todelete[KEY], None)
+ cache.pop(todelete.key, None)
@synchronized
def cache_set_default(key, value):
node = cache.get(key, None)
if node is not None:
- return node[VALUE]
+ return node.value
else:
add_node(key, value)
if len(cache) > max_size:
- todelete = list_root[PREV]
+ todelete = list_root.prev_node
delete_node(todelete)
- cache.pop(todelete[KEY], None)
+ cache.pop(todelete.key, None)
return value
@synchronized
@@ -119,8 +150,8 @@ class LruCache(object):
node = cache.get(key, None)
if node:
delete_node(node)
- cache.pop(node[KEY], None)
- return node[VALUE]
+ cache.pop(node.key, None)
+ return node.value
else:
return default
@@ -137,8 +168,11 @@ class LruCache(object):
@synchronized
def cache_clear():
- list_root[NEXT] = list_root
- list_root[PREV] = list_root
+ list_root.next_node = list_root
+ list_root.prev_node = list_root
+ for node in cache.values():
+ for cb in node.callbacks:
+ cb()
cache.clear()
@synchronized
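
A sketch of the new invalidation callbacks; the keys are illustrative:

    from synapse.util.caches.lrucache import LruCache

    cache = LruCache(max_size=2)
    evicted = []

    # The callback fires when the entry is evicted, overwritten with a
    # different value, invalidated, or the cache is cleared.
    cache.set("a", 1, callback=lambda: evicted.append("a"))
    cache.set("b", 2)
    cache.set("c", 3)  # over capacity: "a" is least recently used, so evicted

    assert evicted == ["a"]
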
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
new file mode 100644
index 0000000000..00af539880
--- /dev/null
+++ b/synapse/util/caches/response_cache.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.async import ObservableDeferred
+
+
+class ResponseCache(object):
+ """
+ This caches a deferred response. Until the deferred completes it will be
+ returned from the cache. This means that if the client retries the request
+ while the response is still being computed, that original response will be
+ used rather than trying to compute a new response.
+ """
+
+ def __init__(self, hs, timeout_ms=0):
+ self.pending_result_cache = {} # Requests that haven't finished yet.
+
+ self.clock = hs.get_clock()
+ self.timeout_sec = timeout_ms / 1000.
+
+ def get(self, key):
+ result = self.pending_result_cache.get(key)
+ if result is not None:
+ return result.observe()
+ else:
+ return None
+
+ def set(self, key, deferred):
+ result = ObservableDeferred(deferred, consumeErrors=True)
+ self.pending_result_cache[key] = result
+
+ def remove(r):
+ if self.timeout_sec:
+ self.clock.call_later(
+ self.timeout_sec,
+ self.pending_result_cache.pop, key, None,
+ )
+ else:
+ self.pending_result_cache.pop(key, None)
+ return r
+
+ result.addBoth(remove)
+ return result.observe()
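
A usage sketch; `hs` and `compute_response()` stand in for a real homeserver and handler:

    from synapse.util.caches.response_cache import ResponseCache

    class Handler(object):
        def __init__(self, hs):
            # Keep completed responses around for 30s before expiring them.
            self.response_cache = ResponseCache(hs, timeout_ms=30 * 1000)

        def on_request(self, request_key):
            result = self.response_cache.get(request_key)
            if result is None:
                # First caller computes; concurrent retries observe the same
                # deferred instead of recomputing the response.
                result = self.response_cache.set(
                    request_key, compute_response(request_key)  # hypothetical
                )
            return result
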
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index ea8a74ca69..b72bb0ff02 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
from blist import sorteddict
@@ -42,7 +42,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- caches_by_name[self.name] = self._cache
+ self.metrics = register_cache(self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
@@ -53,19 +53,19 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
if stream_pos < latest_entity_change_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
def get_entities_changed(self, entities, stream_pos):
@@ -82,10 +82,10 @@ class StreamChangeCache(object):
self._cache[k] for k in keys[i:]
).intersection(entities)
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
else:
result = entities
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return result
@@ -121,3 +121,9 @@ class StreamChangeCache(object):
k, r = self._cache.popitem()
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
self._entity_to_key.pop(r, None)
+
+ def get_max_pos_of_last_change(self, entity):
+ """Returns an upper bound of the stream id of the last change to an
+ entity.
+ """
+ return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
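
A sketch of the new helper; it is only an upper bound because uncached entities fall back to the earliest known stream position:

    from synapse.util.caches.stream_change_cache import StreamChangeCache

    cache = StreamChangeCache("ExampleCache", current_stream_pos=100)
    cache.entity_has_changed(u"@alice:example.com", 120)

    assert cache.get_max_pos_of_last_change(u"@alice:example.com") == 120
    # Never seen (or evicted): conservatively report the earliest known pos.
    assert cache.get_max_pos_of_last_change(u"@bob:example.com") == 100
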
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 03bc1401b7..c31585aea3 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,9 @@ class TreeCache(object):
self.size -= cnt
return popped
+ def values(self):
+ return [e.value for e in self.root.values()]
+
def __len__(self):
return self.size
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 8875813de4..e68f94ce77 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,7 +15,9 @@
from twisted.internet import defer
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_fn
+)
from synapse.util import unwrapFirstError
@@ -25,6 +27,20 @@ import logging
logger = logging.getLogger(__name__)
+def user_left_room(distributor, user, room_id):
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_left_room", user=user, room_id=room_id
+ )
+
+
+def user_joined_room(distributor, user, room_id):
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_joined_room", user=user, room_id=room_id
+ )
+
+
class Distributor(object):
"""A central dispatch point for loosely-connected pieces of code to
register, observe, and fire signals.
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
new file mode 100644
index 0000000000..45be47159a
--- /dev/null
+++ b/synapse/util/httpresourcetree.py
@@ -0,0 +1,98 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def create_resource_tree(desired_tree, root_resource):
+ """Create the resource tree for this Home Server.
+
+ This is unduly complicated because Twisted does not support putting
+ child resources more than 1 level deep at a time.
+
+ Args:
+ desired_tree (dict): Mapping from full path strings to the Resource
+ to attach at each path.
+ root_resource (twisted.web.resource.Resource): The root
+ resource to add the tree to.
+ Returns:
+ twisted.web.resource.Resource: the ``root_resource`` with a tree of
+ child resources added to it.
+ """
+
+ # ideally we'd just use getChild and putChild but getChild doesn't work
+ # unless you give it a Request object IN ADDITION to the name :/ So
+ # instead, we'll store a copy of this mapping so we can actually add
+ # extra resources to existing nodes. See self._resource_id for the key.
+ resource_mappings = {}
+ for full_path, res in desired_tree.items():
+ logger.info("Attaching %s to path %s", res, full_path)
+ last_resource = root_resource
+ for path_seg in full_path.split('/')[1:-1]:
+ if path_seg not in last_resource.listNames():
+ # resource doesn't exist, so make a "dummy resource"
+ child_resource = Resource()
+ last_resource.putChild(path_seg, child_resource)
+ res_id = _resource_id(last_resource, path_seg)
+ resource_mappings[res_id] = child_resource
+ last_resource = child_resource
+ else:
+ # we have an existing Resource, use that instead.
+ res_id = _resource_id(last_resource, path_seg)
+ last_resource = resource_mappings[res_id]
+
+ # ===========================
+ # now attach the actual desired resource
+ last_path_seg = full_path.split('/')[-1]
+
+ # if there is already a resource here, thieve its children and
+ # replace it
+ res_id = _resource_id(last_resource, last_path_seg)
+ if res_id in resource_mappings:
+ # there is a dummy resource at this path already, which needs
+ # to be replaced with the desired resource.
+ existing_dummy_resource = resource_mappings[res_id]
+ for child_name in existing_dummy_resource.listNames():
+ child_res_id = _resource_id(
+ existing_dummy_resource, child_name
+ )
+ child_resource = resource_mappings[child_res_id]
+ # steal the children
+ res.putChild(child_name, child_resource)
+
+ # finally, insert the desired resource in the right place
+ last_resource.putChild(last_path_seg, res)
+ res_id = _resource_id(last_resource, last_path_seg)
+ resource_mappings[res_id] = res
+
+ return root_resource
+
+
+def _resource_id(resource, path_seg):
+ """Construct an arbitrary resource ID so you can retrieve the mapping
+ later.
+
+ If you want to represent resource A putChild resource B with path C,
+ the mapping should look like _resource_id(A,C) = B.
+
+ Args:
+ resource (Resource): The *parent* Resource.
+ path_seg (str): The name of the child Resource to be attached.
+ Returns:
+ str: A unique string which can be a key to the child Resource.
+ """
+ return "%s-%s" % (resource, path_seg)
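
A usage sketch; `FooResource` stands in for real servlet resources:

    from twisted.web.resource import Resource
    from synapse.util.httpresourcetree import create_resource_tree

    class FooResource(Resource):
        isLeaf = True

    # Both paths hang off a shared "_matrix" dummy node under the root.
    root = create_resource_tree(
        {
            "/_matrix/client": FooResource(),
            "/_matrix/federation": FooResource(),
        },
        Resource(),
    )
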
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 3fd5c3d9fd..d668e5a6b8 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -76,15 +76,26 @@ class JsonEncodedObject(object):
d.update(self.unrecognized_keys)
return d
+ def get_internal_dict(self):
+ d = {
+ k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
+ if k in self.valid_keys
+ }
+ d.update(self.unrecognized_keys)
+ return d
+
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
-def _encode(obj):
+def _encode(obj, internal=False):
if type(obj) is list:
- return [_encode(o) for o in obj]
+ return [_encode(o, internal=internal) for o in obj]
if isinstance(obj, JsonEncodedObject):
- return obj.get_dict()
+ if internal:
+ return obj.get_internal_dict()
+ else:
+ return obj.get_dict()
return obj
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 5316259d15..6c83eb213d 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs):
return res
-def preserve_context_over_deferred(deferred):
+def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
- current_context = LoggingContext.current_context()
- d = _PreservingContextDeferred(current_context)
+ if context is None:
+ context = LoggingContext.current_context()
+ d = _PreservingContextDeferred(context)
deferred.chainDeferred(d)
return d
@@ -316,8 +317,13 @@ def preserve_fn(f):
def g(*args, **kwargs):
with PreserveLoggingContext(current):
- return f(*args, **kwargs)
-
+ res = f(*args, **kwargs)
+ if isinstance(res, defer.Deferred):
+ return preserve_context_over_deferred(
+ res, context=LoggingContext.sentinel
+ )
+ else:
+ return res
return g
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
new file mode 100644
index 0000000000..97e0f00b67
--- /dev/null
+++ b/synapse/util/manhole.py
@@ -0,0 +1,70 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.conch.manhole import ColoredManhole
+from twisted.conch.insults import insults
+from twisted.conch import manhole_ssh
+from twisted.cred import checkers, portal
+from twisted.conch.ssh.keys import Key
+
+PUBLIC_KEY = (
+ "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
+ "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS"
+ "kbh/C+BR3utDS555mV"
+)
+
+PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
+MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
+4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
+vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
+Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
+xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
+PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
+gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
+DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
+pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
+EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
+-----END RSA PRIVATE KEY-----"""
+
+
+def manhole(username, password, globals):
+ """Starts an SSH listener with password authentication, using
+ the given username and password. Clients connecting to the SSH
+ listener will find themselves in a colored python shell with
+ the supplied globals.
+
+ Args:
+ username(str): The username ssh clients should auth with.
+ password(str): The password ssh clients should auth with.
+ globals(dict): The variables to expose in the shell.
+
+ Returns:
+ twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
+ """
+
+ checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
+ **{username: password}
+ )
+
+ rlm = manhole_ssh.TerminalRealm()
+ rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
+ ColoredManhole,
+ dict(globals, __name__="__console__")
+ )
+
+ factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
+ factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY)
+ factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY)
+
+ return factory
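
A usage sketch; since the host key above is baked into the source, bind to loopback only. `hs` is an illustrative global:

    from twisted.internet import reactor
    from synapse.util.manhole import manhole

    factory = manhole("matrix", "rabbithole", {"hs": hs})  # hs: illustrative
    reactor.listenTCP(9000, factory, interface="127.0.0.1")
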
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c51b641125..4ea930d3e8 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.logcontext import LoggingContext
import synapse.metrics
+from functools import wraps
import logging
@@ -47,10 +49,22 @@ block_db_txn_duration = metrics.register_distribution(
)
+def measure_func(name):
+ def wrapper(func):
+ @wraps(func)
+ @defer.inlineCallbacks
+ def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, name):
+ r = yield func(self, *args, **kwargs)
+ defer.returnValue(r)
+ return measured_func
+ return wrapper
+
+
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration"
+ "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
]
def __init__(self, clock, name):
@@ -58,17 +72,22 @@ class Measure(object):
self.name = name
self.start_context = None
self.start = None
+ self.created_context = False
def __enter__(self):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
- if self.start_context:
- self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
- self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ if not self.start_context:
+ self.start_context = LoggingContext("Measure")
+ self.start_context.__enter__()
+ self.created_context = True
+
+ self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+ self.db_txn_count = self.start_context.db_txn_count
+ self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb):
- if exc_type is not None or not self.start_context:
+ # exc_type is an exception class (or None), so use issubclass here
+ if (exc_type is not None and issubclass(exc_type, Exception)) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
@@ -78,8 +97,8 @@ class Measure(object):
if context != self.start_context:
logger.warn(
- "Context have unexpectedly changed from '%s' to '%s'. (%r)",
- context, self.start_context, self.name
+ "Context has unexpectedly changed from '%s' to '%s'. (%r)",
+ self.start_context, context, self.name
)
return
@@ -91,7 +110,12 @@ class Measure(object):
block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
- block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+ block_db_txn_count.inc_by(
+ context.db_txn_count - self.db_txn_count, self.name
+ )
block_db_txn_duration.inc_by(
context.db_txn_duration - self.db_txn_duration, self.name
)
+
+ if self.created_context:
+ self.start_context.__exit__(exc_type, exc_val, exc_tb)
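
A usage sketch of `measure_func`, which assumes the decorated method's instance exposes `self.clock` (as synapse handlers do); `SyncHandler` and `do_expensive_thing` are illustrative:

    from twisted.internet import defer
    from synapse.util.metrics import measure_func

    class SyncHandler(object):
        def __init__(self, hs):
            self.clock = hs.get_clock()

        @measure_func("sync.compute")
        @defer.inlineCallbacks
        def compute_sync(self, user_id):
            # The whole body is timed under the "sync.compute" block metrics.
            res = yield self.do_expensive_thing(user_id)  # hypothetical
            defer.returnValue(res)
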
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 4076eed269..1101881a2d 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -100,20 +100,6 @@ class _PerHostRatelimiter(object):
self.current_processing = set()
self.request_times = []
- def is_empty(self):
- time_now = self.clock.time_msec()
- self.request_times[:] = [
- r for r in self.request_times
- if time_now - r < self.window_size
- ]
-
- return not (
- self.ready_request_queue
- or self.sleeping_requests
- or self.current_processing
- or self.request_times
- )
-
@contextlib.contextmanager
def ratelimit(self):
# `contextlib.contextmanager` takes a generator and turns it into a
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 43cf11f3f6..e2de7fce91 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -121,15 +121,9 @@ class RetryDestinationLimiter(object):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
- def err(failure):
- logger.exception(
- "Failed to store set_destination_retry_timings",
- failure.value
- )
-
valid_err_code = False
- if exc_type is CodeMessageException:
- valid_err_code = 0 <= exc_val.code < 500
+ if exc_type is not None and issubclass(exc_type, CodeMessageException):
+ valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500
if exc_type is None or valid_err_code:
# We connected successfully.
@@ -151,6 +145,15 @@ class RetryDestinationLimiter(object):
retry_last_ts = int(self.clock.time_msec())
- self.store.set_destination_retry_timings(
- self.destination, retry_last_ts, self.retry_interval
- ).addErrback(err)
+ @defer.inlineCallbacks
+ def store_retry_timings():
+ try:
+ yield self.store.set_destination_retry_timings(
+ self.destination, retry_last_ts, self.retry_interval
+ )
+ except:
+ logger.exception(
+ "Failed to store set_destination_retry_timings",
+ )
+
+ store_retry_timings()
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
new file mode 100644
index 0000000000..f4a9abf83f
--- /dev/null
+++ b/synapse/util/rlimit.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import resource
+import logging
+
+
+logger = logging.getLogger("synapse.app.homeserver")
+
+
+def change_resource_limit(soft_file_no):
+ try:
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+
+ if not soft_file_no:
+ soft_file_no = hard
+
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
+ logger.info("Set file limit to: %d", soft_file_no)
+
+ resource.setrlimit(
+ resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
+ )
+ except (ValueError, resource.error) as e:
+ logger.warn("Failed to set file or core limit: %s", e)
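
A usage sketch: a falsy argument raises the soft open-files limit to the hard limit:

    import resource
    from synapse.util.rlimit import change_resource_limit

    change_resource_limit(0)
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    assert soft == hard
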
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index b490bb8725..a100f151d4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -21,10 +21,6 @@ _string_with_symbols = (
)
-def origin_from_ucid(ucid):
- return ucid.split("@", 1)[1]
-
-
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
new file mode 100644
index 0000000000..52086df465
--- /dev/null
+++ b/synapse/util/versionstring.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def get_version_string(module):
+ try:
+ null = open(os.devnull, 'w')
+ cwd = os.path.dirname(os.path.abspath(module.__file__))
+ try:
+ git_branch = subprocess.check_output(
+ ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
+ stderr=null,
+ cwd=cwd,
+ ).strip()
+ git_branch = "b=" + git_branch
+ except subprocess.CalledProcessError:
+ git_branch = ""
+
+ try:
+ git_tag = subprocess.check_output(
+ ['git', 'describe', '--exact-match'],
+ stderr=null,
+ cwd=cwd,
+ ).strip()
+ git_tag = "t=" + git_tag
+ except subprocess.CalledProcessError:
+ git_tag = ""
+
+ try:
+ git_commit = subprocess.check_output(
+ ['git', 'rev-parse', '--short', 'HEAD'],
+ stderr=null,
+ cwd=cwd,
+ ).strip()
+ except subprocess.CalledProcessError:
+ git_commit = ""
+
+ try:
+ dirty_string = "-this_is_a_dirty_checkout"
+ is_dirty = subprocess.check_output(
+ ['git', 'describe', '--dirty=' + dirty_string],
+ stderr=null,
+ cwd=cwd,
+ ).strip().endswith(dirty_string)
+
+ git_dirty = "dirty" if is_dirty else ""
+ except subprocess.CalledProcessError:
+ git_dirty = ""
+
+ if git_branch or git_tag or git_commit or git_dirty:
+ git_version = ",".join(
+ s for s in
+ (git_branch, git_tag, git_commit, git_dirty,)
+ if s
+ )
+
+ return (
+ "%s (%s)" % (
+ module.__version__, git_version,
+ )
+ ).encode("ascii")
+ except Exception as e:
+ logger.info("Failed to check for git repository: %s", e)
+
+ return module.__version__.encode("ascii")
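
A usage sketch:

    import synapse
    from synapse.util.versionstring import get_version_string

    # Inside a git checkout this yields e.g. "0.16.0 (b=develop,abc1234)";
    # elsewhere it falls back to the bare module version.
    print(get_version_string(synapse))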