path: root/synapse/util
Diffstat (limited to 'synapse/util')
-rw-r--r--  synapse/util/__init__.py                     16
-rw-r--r--  synapse/util/async.py                       172
-rw-r--r--  synapse/util/caches/__init__.py              77
-rw-r--r--  synapse/util/caches/descriptors.py          202
-rw-r--r--  synapse/util/caches/dictionary_cache.py       8
-rw-r--r--  synapse/util/caches/expiringcache.py          8
-rw-r--r--  synapse/util/caches/lrucache.py             106
-rw-r--r--  synapse/util/caches/response_cache.py        55
-rw-r--r--  synapse/util/caches/stream_change_cache.py   22
-rw-r--r--  synapse/util/caches/treecache.py              3
-rw-r--r--  synapse/util/distributor.py                  18
-rw-r--r--  synapse/util/httpresourcetree.py             98
-rw-r--r--  synapse/util/jsonobject.py                   17
-rw-r--r--  synapse/util/logcontext.py                   16
-rw-r--r--  synapse/util/manhole.py                      70
-rw-r--r--  synapse/util/metrics.py                      42
-rw-r--r--  synapse/util/ratelimitutils.py               14
-rw-r--r--  synapse/util/retryutils.py                   25
-rw-r--r--  synapse/util/rlimit.py                       37
-rw-r--r--  synapse/util/stringutils.py                   4
-rw-r--r--  synapse/util/versionstring.py                84
21 files changed, 919 insertions, 175 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py

index 3b9da5b34a..c05b9450be 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.errors import SynapseError from synapse.util.logcontext import PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -33,7 +34,7 @@ class Clock(object): """A small utility that obtains current time-of-day so that time may be mocked during unit-tests. - TODO(paul): Also move the sleep() functionallity into it + TODO(paul): Also move the sleep() functionality into it """ def time(self): @@ -45,13 +46,18 @@ class Clock(object): return int(self.time() * 1000) def looping_call(self, f, msec): + """Call a function repeatedly. + + Waits `msec` initially before calling `f` for the first time. + + Args: + f(function): The function to call repeatedly. + msec(float): How long to wait between calls in milliseconds. + """ l = task.LoopingCall(f) l.start(msec / 1000.0, now=False) return l - def stop_looping_call(self, loop): - loop.stop() - def call_later(self, delay, callback, *args, **kwargs): """Call something later @@ -83,7 +89,7 @@ class Clock(object): def timed_out_fn(): try: - ret_deferred.errback(RuntimeError("Timed out")) + ret_deferred.errback(SynapseError(504, "Timed out")) except: pass diff --git a/synapse/util/async.py b/synapse/util/async.py
index 640fae3890..347fb1e380 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py
@@ -16,7 +16,12 @@ from twisted.internet import defer, reactor -from .logcontext import PreserveLoggingContext +from .logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, +) +from synapse.util import unwrapFirstError + +from contextlib import contextmanager @defer.inlineCallbacks @@ -97,6 +102,15 @@ class ObservableDeferred(object): def observers(self): return self._observers + def has_called(self): + return self._result is not None + + def has_succeeded(self): + return self._result is not None and self._result[0] is True + + def get_result(self): + return self._result[1] + def __getattr__(self, name): return getattr(self._deferred, name) @@ -107,3 +121,159 @@ class ObservableDeferred(object): return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( id(self), self._result, self._deferred, ) + + +def concurrently_execute(func, args, limit): + """Executes the function with each argument conncurrently while limiting + the number of concurrent executions. + + Args: + func (func): Function to execute, should return a deferred. + args (list): List of arguments to pass to func, each invocation of func + gets a signle argument. + limit (int): Maximum number of conccurent executions. + + Returns: + deferred: Resolved when all function invocations have finished. + """ + it = iter(args) + + @defer.inlineCallbacks + def _concurrently_execute_inner(): + try: + while True: + yield func(it.next()) + except StopIteration: + pass + + return preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(_concurrently_execute_inner)() + for _ in xrange(limit) + ], consumeErrors=True)).addErrback(unwrapFirstError) + + +class Linearizer(object): + """Linearizes access to resources based on a key. Useful to ensure only one + thing is happening at a time on a given resource. + + Example: + + with (yield linearizer.queue("test_key")): + # do some work. + + """ + def __init__(self): + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + # If there is already a deferred in the queue, we pull it out so that + # we can wait on it later. + # Then we replace it with a deferred that we resolve *after* the + # context manager has exited. + # We only return the context manager after the previous deferred has + # resolved. + # This all has the net effect of creating a chain of deferreds that + # wait for the previous deferred before starting their work. + current_defer = self.key_to_defer.get(key) + + new_defer = defer.Deferred() + self.key_to_defer[key] = new_defer + + if current_defer: + with PreserveLoggingContext(): + yield current_defer + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + current_d = self.key_to_defer.get(key) + if current_d is new_defer: + self.key_to_defer.pop(key, None) + + defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. 
+ + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield curr_writer + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 1a14904194..ebd715c5dc 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -14,14 +14,81 @@ # limitations under the License. import synapse.metrics +from lrucache import LruCache +import os + +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) DEBUG_CACHES = False metrics = synapse.metrics.get_metrics_for("synapse.util.caches") caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) +# cache_counter = metrics.register_cache( +# "cache", +# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, +# labels=["name"], +# ) + + +def register_cache(name, cache): + caches_by_name[name] = cache + return metrics.register_cache( + "cache", + lambda: len(cache), + name, + ) + + +_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR)) +caches_by_name["string_cache"] = _string_cache + + +KNOWN_KEYS = { + key: key for key in + ( + "auth_events", + "content", + "depth", + "event_id", + "hashes", + "origin", + "origin_server_ts", + "prev_events", + "room_id", + "sender", + "signatures", + "state_key", + "type", + "unsigned", + "user_id", + ) +} + + +def intern_string(string): + """Takes a (potentially) unicode string and interns using custom cache + """ + return _string_cache.setdefault(string, string) + + +def intern_dict(dictionary): + """Takes a dictionary and interns well known keys and their values + """ + return { + KNOWN_KEYS.get(key, key): _intern_known_values(key, value) + for key, value in dictionary.items() + } + + +def _intern_known_values(key, value): + intern_str_keys = ("event_id", "room_id") + intern_unicode_keys = ("sender", "user_id", "type", "state_key") + + if key in intern_str_keys: + return intern(value.encode('ascii')) + + if key in intern_unicode_keys: + return intern_string(value) + + return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 35544b19fd..8dba61d49f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -22,17 +22,17 @@ from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) -from . import caches_by_name, DEBUG_CACHES, cache_counter +from . import DEBUG_CACHES, register_cache from twisted.internet import defer - -from collections import OrderedDict +from collections import namedtuple import os import functools import inspect import threading + logger = logging.getLogger(__name__) @@ -43,23 +43,27 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) class Cache(object): + __slots__ = ( + "cache", + "max_entries", + "name", + "keylen", + "sequence", + "thread", + "metrics", + ) - def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): - if lru: - cache_type = TreeCache if tree else dict - self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type - ) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries + def __init__(self, name, max_entries=1000, keylen=1, tree=False): + cache_type = TreeCache if tree else dict + self.cache = LruCache( + max_size=max_entries, keylen=keylen, cache_type=cache_type + ) self.name = name self.keylen = keylen self.sequence = 0 self.thread = None - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -71,32 +75,28 @@ class Cache(object): "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) + def get(self, key, default=_CacheSentinel, callback=None): + val = self.cache.get(key, _CacheSentinel, callback=callback) if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return val - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() if default is _CacheSentinel: raise KeyError() else: return default - def update(self, sequence, key, value): + def update(self, sequence, key, value, callback=None): self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) - - def prefill(self, key, value): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) + self.prefill(key, value, callback=callback) - self.cache[key] = value + def prefill(self, key, value, callback=None): + self.cache.set(key, value, callback=callback) def invalidate(self, key): self.check_thread() @@ -141,9 +141,21 @@ class CacheDescriptor(object): The wrapped function has another additional callable, called "prefill", which can be used to insert values into the cache specifically, without calling the calculation function. + + Cached functions can be "chained" (i.e. a cached function can call other cached + functions and get appropriately invalidated when they called caches are + invalidated) by adding a special "cache_context" argument to the function + and passing that as a kwarg to all caches called. 
For example:: + + @cachedInlineCallbacks(cache_context=True) + def foo(self, key, cache_context): + r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) + defer.returnValue(r1 + r2) + """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, - inlineCallbacks=False): + def __init__(self, orig, max_entries=1000, num_args=1, tree=False, + inlineCallbacks=False, cache_context=False): max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.orig = orig @@ -155,34 +167,64 @@ class CacheDescriptor(object): self.max_entries = max_entries self.num_args = num_args - self.lru = lru self.tree = tree - self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] + all_args = inspect.getargspec(orig) + self.arg_names = all_args.args[1:num_args + 1] + + if "cache_context" in all_args.args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + try: + self.arg_names.remove("cache_context") + except ValueError: + pass + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) + + self.add_cache_context = cache_context if len(self.arg_names) < self.num_args: raise Exception( "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" + " (@cached cannot key off of *args or **kwargs)" % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, - lru=self.lru, tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + + # Add temp cache_context so inspect.getcallargs doesn't explode + if self.add_cache_context: + kwargs["cache_context"] = None + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext(cache, cache_key) + try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key, callback=invalidate_callback) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +246,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +255,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret, callback=invalidate_callback) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + 
wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +283,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +307,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,34 +321,45 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): + # If we're passed a cache_context then we'll want to call its invalidate() + # whenever we are invalidated + invalidate_callback = kwargs.pop("on_invalidate", None) + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] # cached is a dict arg -> deferred, where deferred results in a # 2-tuple (`arg`, `result`) - cached = {} + results = {} + cached_defers = {} missing = [] for arg in list_args: key = list(keyargs) key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + res = cache.get(tuple(key), callback=invalidate_callback) + if not res.has_succeeded(): + res = res.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached_defers[arg] = res + else: + results[arg] = res.get_result() except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,50 +382,67 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update( + sequence, tuple(key), observer, + callback=invalidate_callback + ) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + cached_defers[arg] = res + + if cached_defers: + def update_results_dict(res): + results.update(res) + return results - return preserve_context_over_deferred(defer.gatherResults( - cached.values(), - consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))) + return preserve_context_over_deferred(defer.gatherResults( + cached_defers.values(), + consumeErrors=True, + ).addCallback(update_results_dict).addErrback( + unwrapFirstError + )) + else: + return results obj.__dict__[self.orig.__name__] = wrapped return wrapped -def cached(max_entries=1000, num_args=1, lru=True, tree=False): +class 
_CacheContext(namedtuple("_CacheContext", ("cache", "key"))): + def invalidate(self): + self.cache.invalidate(self.key) + + +def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, + cache_context=cache_context, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru, tree=tree, inlineCallbacks=True, + cache_context=cache_context, ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +472,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks,
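Aside (not part of the commit): a rough sketch of how the reworked decorators fit together after this change; ExampleStore, get_thing, get_things and get_derived_thing are hypothetical names used only for illustration, not Synapse APIs.

from twisted.internet import defer

from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList


class ExampleStore(object):
    # Hypothetical store class; only illustrates the decorator plumbing.

    @cached(num_args=1)
    def get_thing(self, thing_id):
        # Would normally hit the database; must return a Deferred.
        return defer.succeed({"id": thing_id})

    @cachedList(cached_method_name="get_thing", list_name="thing_ids", num_args=1)
    def get_things(self, thing_ids):
        # Called only with the ids missing from get_thing's cache; the
        # returned dict is merged with whatever was already cached.
        return defer.succeed({thing_id: {"id": thing_id} for thing_id in thing_ids})

    @cachedInlineCallbacks(num_args=1, cache_context=True)
    def get_derived_thing(self, thing_id, cache_context):
        # Passing cache_context.invalidate as on_invalidate chains the caches:
        # invalidating get_thing(thing_id) also drops this entry.
        thing = yield self.get_thing(thing_id, on_invalidate=cache_context.invalidate)
        defer.returnValue(thing)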
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index f92d80542b..b0ca1bb79d 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py
@@ -15,7 +15,7 @@ from synapse.util.caches.lrucache import LruCache from collections import namedtuple -from . import caches_by_name, cache_counter +from . import register_cache import threading import logging @@ -43,7 +43,7 @@ class DictionaryCache(object): __slots__ = [] self.sentinel = Sentinel() - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -58,7 +58,7 @@ class DictionaryCache(object): def get(self, key, dict_keys=None): entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() if dict_keys is None: return DictionaryEntry(entry.full, dict(entry.value)) @@ -69,7 +69,7 @@ class DictionaryCache(object): if k in entry.value }) - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return DictionaryEntry(False, {}) def invalidate(self, key): diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2b68c1ac93..080388958f 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache import logging @@ -49,7 +49,7 @@ class ExpiringCache(object): self._cache = {} - caches_by_name[cache_name] = self._cache + self.metrics = register_cache(cache_name, self._cache) def start(self): if not self._expiry_ms: @@ -78,9 +78,9 @@ class ExpiringCache(object): def __getitem__(self, key): try: entry = self._cache[key] - cache_counter.inc_hits(self._cache_name) + self.metrics.inc_hits() except KeyError: - cache_counter.inc_misses(self._cache_name) + self.metrics.inc_misses() raise if self._reset_expiry_on_get: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f7423f2fab..9c4c679175 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py
@@ -29,19 +29,32 @@ def enumerate_leaves(node, depth): yield m +class _Node(object): + __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"] + + def __init__(self, prev_node, next_node, key, value, callbacks=set()): + self.prev_node = prev_node + self.next_node = next_node + self.key = key + self.value = value + self.callbacks = callbacks + + class LruCache(object): """ Least-recently-used cache. Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. + + Can also set callbacks on objects when getting/setting which are fired + when that key gets invalidated/evicted. """ def __init__(self, max_size, keylen=1, cache_type=dict): cache = cache_type() self.cache = cache # Used for introspection. - list_root = [] - list_root[:] = [list_root, list_root, None, None] - - PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 + list_root = _Node(None, None, None, None) + list_root.next_node = list_root + list_root.prev_node = list_root lock = threading.Lock() @@ -53,65 +66,83 @@ class LruCache(object): return inner - def add_node(key, value): + def add_node(key, value, callbacks=set()): prev_node = list_root - next_node = prev_node[NEXT] - node = [prev_node, next_node, key, value] - prev_node[NEXT] = node - next_node[PREV] = node + next_node = prev_node.next_node + node = _Node(prev_node, next_node, key, value, callbacks) + prev_node.next_node = node + next_node.prev_node = node cache[key] = node def move_node_to_front(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node + prev_node = node.prev_node + next_node = node.next_node + prev_node.next_node = next_node + next_node.prev_node = prev_node prev_node = list_root - next_node = prev_node[NEXT] - node[PREV] = prev_node - node[NEXT] = next_node - prev_node[NEXT] = node - next_node[PREV] = node + next_node = prev_node.next_node + node.prev_node = prev_node + node.next_node = next_node + prev_node.next_node = node + next_node.prev_node = node def delete_node(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node + prev_node = node.prev_node + next_node = node.next_node + prev_node.next_node = next_node + next_node.prev_node = prev_node + + for cb in node.callbacks: + cb() + node.callbacks.clear() @synchronized - def cache_get(key, default=None): + def cache_get(key, default=None, callback=None): node = cache.get(key, None) if node is not None: move_node_to_front(node) - return node[VALUE] + if callback: + node.callbacks.add(callback) + return node.value else: return default @synchronized - def cache_set(key, value): + def cache_set(key, value, callback=None): node = cache.get(key, None) if node is not None: + if value != node.value: + for cb in node.callbacks: + cb() + node.callbacks.clear() + + if callback: + node.callbacks.add(callback) + move_node_to_front(node) - node[VALUE] = value + node.value = value else: - add_node(key, value) + if callback: + callbacks = set([callback]) + else: + callbacks = set() + add_node(key, value, callbacks) if len(cache) > max_size: - todelete = list_root[PREV] + todelete = list_root.prev_node delete_node(todelete) - cache.pop(todelete[KEY], None) + cache.pop(todelete.key, None) @synchronized def cache_set_default(key, value): node = cache.get(key, None) if node is not None: - return node[VALUE] + return node.value else: add_node(key, value) if len(cache) > max_size: - todelete = list_root[PREV] + todelete = list_root.prev_node delete_node(todelete) - 
cache.pop(todelete[KEY], None) + cache.pop(todelete.key, None) return value @synchronized @@ -119,8 +150,8 @@ class LruCache(object): node = cache.get(key, None) if node: delete_node(node) - cache.pop(node[KEY], None) - return node[VALUE] + cache.pop(node.key, None) + return node.value else: return default @@ -137,8 +168,11 @@ class LruCache(object): @synchronized def cache_clear(): - list_root[NEXT] = list_root - list_root[PREV] = list_root + list_root.next_node = list_root + list_root.prev_node = list_root + for node in cache.values(): + for cb in node.callbacks: + cb() cache.clear() @synchronized
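Aside (not part of the commit): a minimal sketch of the new invalidation callbacks, using the get/set methods the Cache class above relies on; the keys are arbitrary.

from synapse.util.caches.lrucache import LruCache

cache = LruCache(max_size=2)
invalidated = []

# Callbacks attached on set (or get) fire when the entry is evicted,
# overwritten with a different value, or cleared.
cache.set("a", 1, callback=lambda: invalidated.append("a"))
cache.set("b", 2)
cache.set("c", 3)            # cache is now over max_size, so "a" is evicted...
assert invalidated == ["a"]  # ...and its callback has fired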
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py new file mode 100644
index 0000000000..00af539880 --- /dev/null +++ b/synapse/util/caches/response_cache.py
@@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.async import ObservableDeferred + + +class ResponseCache(object): + """ + This caches a deferred response. Until the deferred completes it will be + returned from the cache. This means that if the client retries the request + while the response is still being computed, that original response will be + used rather than trying to compute a new response. + """ + + def __init__(self, hs, timeout_ms=0): + self.pending_result_cache = {} # Requests that haven't finished yet. + + self.clock = hs.get_clock() + self.timeout_sec = timeout_ms / 1000. + + def get(self, key): + result = self.pending_result_cache.get(key) + if result is not None: + return result.observe() + else: + return None + + def set(self, key, deferred): + result = ObservableDeferred(deferred, consumeErrors=True) + self.pending_result_cache[key] = result + + def remove(r): + if self.timeout_sec: + self.clock.call_later( + self.timeout_sec, + self.pending_result_cache.pop, key, None, + ) + else: + self.pending_result_cache.pop(key, None) + return r + + result.addBoth(remove) + return result.observe() diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index ea8a74ca69..b72bb0ff02 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache from blist import sorteddict @@ -42,7 +42,7 @@ class StreamChangeCache(object): self._cache = sorteddict() self._earliest_known_stream_pos = current_stream_pos self.name = name - caches_by_name[self.name] = self._cache + self.metrics = register_cache(self.name, self._cache) for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) @@ -53,19 +53,19 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos < self._earliest_known_stream_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True latest_entity_change_pos = self._entity_to_key.get(entity, None) if latest_entity_change_pos is None: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False if stream_pos < latest_entity_change_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False def get_entities_changed(self, entities, stream_pos): @@ -82,10 +82,10 @@ class StreamChangeCache(object): self._cache[k] for k in keys[i:] ).intersection(entities) - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() else: result = entities - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return result @@ -121,3 +121,9 @@ class StreamChangeCache(object): k, r = self._cache.popitem() self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) self._entity_to_key.pop(r, None) + + def get_max_pos_of_last_change(self, entity): + """Returns an upper bound of the stream id of the last change to an + entity. + """ + return self._entity_to_key.get(entity, self._earliest_known_stream_pos) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 03bc1401b7..c31585aea3 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,9 @@ class TreeCache(object): self.size -= cnt return popped + def values(self): + return [e.value for e in self.root.values()] + def __len__(self): return self.size diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 8875813de4..e68f94ce77 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py
@@ -15,7 +15,9 @@ from twisted.internet import defer -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_fn +) from synapse.util import unwrapFirstError @@ -25,6 +27,20 @@ import logging logger = logging.getLogger(__name__) +def user_left_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) + + +def user_joined_room(distributor, user, room_id): + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) + + class Distributor(object): """A central dispatch point for loosely-connected pieces of code to register, observe, and fire signals. diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py new file mode 100644
index 0000000000..45be47159a --- /dev/null +++ b/synapse/util/httpresourcetree.py
@@ -0,0 +1,98 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource + +import logging + +logger = logging.getLogger(__name__) + + +def create_resource_tree(desired_tree, root_resource): + """Create the resource tree for this Home Server. + + This in unduly complicated because Twisted does not support putting + child resources more than 1 level deep at a time. + + Args: + web_client (bool): True to enable the web client. + root_resource (twisted.web.resource.Resource): The root + resource to add the tree to. + Returns: + twisted.web.resource.Resource: the ``root_resource`` with a tree of + child resources added to it. + """ + + # ideally we'd just use getChild and putChild but getChild doesn't work + # unless you give it a Request object IN ADDITION to the name :/ So + # instead, we'll store a copy of this mapping so we can actually add + # extra resources to existing nodes. See self._resource_id for the key. + resource_mappings = {} + for full_path, res in desired_tree.items(): + logger.info("Attaching %s to path %s", res, full_path) + last_resource = root_resource + for path_seg in full_path.split('/')[1:-1]: + if path_seg not in last_resource.listNames(): + # resource doesn't exist, so make a "dummy resource" + child_resource = Resource() + last_resource.putChild(path_seg, child_resource) + res_id = _resource_id(last_resource, path_seg) + resource_mappings[res_id] = child_resource + last_resource = child_resource + else: + # we have an existing Resource, use that instead. + res_id = _resource_id(last_resource, path_seg) + last_resource = resource_mappings[res_id] + + # =========================== + # now attach the actual desired resource + last_path_seg = full_path.split('/')[-1] + + # if there is already a resource here, thieve its children and + # replace it + res_id = _resource_id(last_resource, last_path_seg) + if res_id in resource_mappings: + # there is a dummy resource at this path already, which needs + # to be replaced with the desired resource. + existing_dummy_resource = resource_mappings[res_id] + for child_name in existing_dummy_resource.listNames(): + child_res_id = _resource_id( + existing_dummy_resource, child_name + ) + child_resource = resource_mappings[child_res_id] + # steal the children + res.putChild(child_name, child_resource) + + # finally, insert the desired resource in the right place + last_resource.putChild(last_path_seg, res) + res_id = _resource_id(last_resource, last_path_seg) + resource_mappings[res_id] = res + + return root_resource + + +def _resource_id(resource, path_seg): + """Construct an arbitrary resource ID so you can retrieve the mapping + later. + + If you want to represent resource A putChild resource B with path C, + the mapping should looks like _resource_id(A,C) = B. + + Args: + resource (Resource): The *parent* Resourceb + path_seg (str): The name of the child Resource to be attached. + Returns: + str: A unique string which can be a key to the child Resource. 
+ """ + return "%s-%s" % (resource, path_seg) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 3fd5c3d9fd..d668e5a6b8 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py
@@ -76,15 +76,26 @@ class JsonEncodedObject(object): d.update(self.unrecognized_keys) return d + def get_internal_dict(self): + d = { + k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + if k in self.valid_keys + } + d.update(self.unrecognized_keys) + return d + def __str__(self): return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) -def _encode(obj): +def _encode(obj, internal=False): if type(obj) is list: - return [_encode(o) for o in obj] + return [_encode(o, internal=internal) for o in obj] if isinstance(obj, JsonEncodedObject): - return obj.get_dict() + if internal: + return obj.get_internal_dict() + else: + return obj.get_dict() return obj diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 5316259d15..6c83eb213d 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py
@@ -297,12 +297,13 @@ def preserve_context_over_fn(fn, *args, **kwargs): return res -def preserve_context_over_deferred(deferred): +def preserve_context_over_deferred(deferred, context=None): """Given a deferred wrap it such that any callbacks added later to it will be invoked with the current context. """ - current_context = LoggingContext.current_context() - d = _PreservingContextDeferred(current_context) + if context is None: + context = LoggingContext.current_context() + d = _PreservingContextDeferred(context) deferred.chainDeferred(d) return d @@ -316,8 +317,13 @@ def preserve_fn(f): def g(*args, **kwargs): with PreserveLoggingContext(current): - return f(*args, **kwargs) - + res = f(*args, **kwargs) + if isinstance(res, defer.Deferred): + return preserve_context_over_deferred( + res, context=LoggingContext.sentinel + ) + else: + return res return g diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py new file mode 100644
index 0000000000..97e0f00b67 --- /dev/null +++ b/synapse/util/manhole.py
@@ -0,0 +1,70 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.conch.manhole import ColoredManhole +from twisted.conch.insults import insults +from twisted.conch import manhole_ssh +from twisted.cred import checkers, portal +from twisted.conch.ssh.keys import Key + +PUBLIC_KEY = ( + "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az" + "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS" + "kbh/C+BR3utDS555mV" +) + +PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- +MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW +4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw +vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb +Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1 +xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8 +PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2 +gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu +DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML +pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP +EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg== +-----END RSA PRIVATE KEY-----""" + + +def manhole(username, password, globals): + """Starts a ssh listener with password authentication using + the given username and password. Clients connecting to the ssh + listener will find themselves in a colored python shell with + the supplied globals. + + Args: + username(str): The username ssh clients should auth with. + password(str): The password ssh clients should auth with. + globals(dict): The variables to expose in the shell. + + Returns: + twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` + """ + + checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( + **{username: password} + ) + + rlm = manhole_ssh.TerminalRealm() + rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( + ColoredManhole, + dict(globals, __name__="__console__") + ) + + factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) + factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY) + factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY) + + return factory diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c51b641125..4ea930d3e8 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py
@@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.util.logcontext import LoggingContext import synapse.metrics +from functools import wraps import logging @@ -47,10 +49,22 @@ block_db_txn_duration = metrics.register_distribution( ) +def measure_func(name): + def wrapper(func): + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, name): + r = yield func(self, *args, **kwargs) + defer.returnValue(r) + return measured_func + return wrapper + + class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration" + "ru_stime", "db_txn_count", "db_txn_duration", "created_context" ] def __init__(self, clock, name): @@ -58,17 +72,22 @@ class Measure(object): self.name = name self.start_context = None self.start = None + self.created_context = False def __enter__(self): self.start = self.clock.time_msec() self.start_context = LoggingContext.current_context() - if self.start_context: - self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() - self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + if not self.start_context: + self.start_context = LoggingContext("Measure") + self.start_context.__enter__() + self.created_context = True + + self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() + self.db_txn_count = self.start_context.db_txn_count + self.db_txn_duration = self.start_context.db_txn_duration def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is not None or not self.start_context: + if isinstance(exc_type, Exception) or not self.start_context: return duration = self.clock.time_msec() - self.start @@ -78,8 +97,8 @@ class Measure(object): if context != self.start_context: logger.warn( - "Context have unexpectedly changed from '%s' to '%s'. (%r)", - context, self.start_context, self.name + "Context has unexpectedly changed from '%s' to '%s'. (%r)", + self.start_context, context, self.name ) return @@ -91,7 +110,12 @@ class Measure(object): block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name) block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name) - block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name) + block_db_txn_count.inc_by( + context.db_txn_count - self.db_txn_count, self.name + ) block_db_txn_duration.inc_by( context.db_txn_duration - self.db_txn_duration, self.name ) + + if self.created_context: + self.start_context.__exit__(exc_type, exc_val, exc_tb) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 4076eed269..1101881a2d 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py
@@ -100,20 +100,6 @@ class _PerHostRatelimiter(object): self.current_processing = set() self.request_times = [] - def is_empty(self): - time_now = self.clock.time_msec() - self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size - ] - - return not ( - self.ready_request_queue - or self.sleeping_requests - or self.current_processing - or self.request_times - ) - @contextlib.contextmanager def ratelimit(self): # `contextlib.contextmanager` takes a generator and turns it into a diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 43cf11f3f6..e2de7fce91 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py
@@ -121,15 +121,9 @@ class RetryDestinationLimiter(object): pass def __exit__(self, exc_type, exc_val, exc_tb): - def err(failure): - logger.exception( - "Failed to store set_destination_retry_timings", - failure.value - ) - valid_err_code = False - if exc_type is CodeMessageException: - valid_err_code = 0 <= exc_val.code < 500 + if exc_type is not None and issubclass(exc_type, CodeMessageException): + valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 if exc_type is None or valid_err_code: # We connected successfully. @@ -151,6 +145,15 @@ class RetryDestinationLimiter(object): retry_last_ts = int(self.clock.time_msec()) - self.store.set_destination_retry_timings( - self.destination, retry_last_ts, self.retry_interval - ).addErrback(err) + @defer.inlineCallbacks + def store_retry_timings(): + try: + yield self.store.set_destination_retry_timings( + self.destination, retry_last_ts, self.retry_interval + ) + except: + logger.exception( + "Failed to store set_destination_retry_timings", + ) + + store_retry_timings()
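Aside (not part of the commit): a hedged sketch of how the limiter is typically used as a context manager, assuming the get_retry_limiter helper defined elsewhere in this module; do_request is hypothetical.

from twisted.internet import defer

from synapse.util.retryutils import get_retry_limiter


@defer.inlineCallbacks
def send_with_backoff(destination, clock, store, do_request):
    # get_retry_limiter raises NotRetryingDestination if we are still inside
    # the backoff window for this destination.
    limiter = yield get_retry_limiter(destination, clock, store)
    with limiter:
        # A clean exit (or a 4xx response other than 429) resets the retry
        # interval; other errors grow it, as implemented in __exit__ above.
        result = yield do_request(destination)
    defer.returnValue(result)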
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py new file mode 100644
index 0000000000..f4a9abf83f --- /dev/null +++ b/synapse/util/rlimit.py
@@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import resource +import logging + + +logger = logging.getLogger("synapse.app.homeserver") + + +def change_resource_limit(soft_file_no): + try: + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + + if not soft_file_no: + soft_file_no = hard + + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard)) + logger.info("Set file limit to: %d", soft_file_no) + + resource.setrlimit( + resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) + ) + except (ValueError, resource.error) as e: + logger.warn("Failed to set file or core limit: %s", e) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index b490bb8725..a100f151d4 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py
@@ -21,10 +21,6 @@ _string_with_symbols = ( ) -def origin_from_ucid(ucid): - return ucid.split("@", 1)[1] - - def random_string(length): return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py new file mode 100644
index 0000000000..52086df465 --- /dev/null +++ b/synapse/util/versionstring.py
@@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import os +import logging + +logger = logging.getLogger(__name__) + + +def get_version_string(module): + try: + null = open(os.devnull, 'w') + cwd = os.path.dirname(os.path.abspath(module.__file__)) + try: + git_branch = subprocess.check_output( + ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + git_branch = "b=" + git_branch + except subprocess.CalledProcessError: + git_branch = "" + + try: + git_tag = subprocess.check_output( + ['git', 'describe', '--exact-match'], + stderr=null, + cwd=cwd, + ).strip() + git_tag = "t=" + git_tag + except subprocess.CalledProcessError: + git_tag = "" + + try: + git_commit = subprocess.check_output( + ['git', 'rev-parse', '--short', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + except subprocess.CalledProcessError: + git_commit = "" + + try: + dirty_string = "-this_is_a_dirty_checkout" + is_dirty = subprocess.check_output( + ['git', 'describe', '--dirty=' + dirty_string], + stderr=null, + cwd=cwd, + ).strip().endswith(dirty_string) + + git_dirty = "dirty" if is_dirty else "" + except subprocess.CalledProcessError: + git_dirty = "" + + if git_branch or git_tag or git_commit or git_dirty: + git_version = ",".join( + s for s in + (git_branch, git_tag, git_commit, git_dirty,) + if s + ) + + return ( + "%s (%s)" % ( + module.__version__, git_version, + ) + ).encode("ascii") + except Exception as e: + logger.info("Failed to check for git repository: %s", e) + + return module.__version__.encode("ascii")
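Aside (not part of the commit): a usage sketch of the new helper; it takes a module and, when run from a git checkout, appends branch/tag/commit/dirty markers to module.__version__.

import synapse

from synapse.util.versionstring import get_version_string

# From a git checkout this looks like "<version> (b=<branch>,<commit>)";
# otherwise it is just module.__version__.
version_string = get_version_string(synapse)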