summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py6
-rw-r--r--synapse/util/async.py50
-rw-r--r--synapse/util/caches/descriptors.py45
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/distributor.py22
-rw-r--r--synapse/util/metrics.py23
-rw-r--r--synapse/util/ratelimitutils.py14
-rw-r--r--synapse/util/rlimit.py37
-rw-r--r--synapse/util/stringutils.py4
-rw-r--r--synapse/util/versionstring.py84
10 files changed, 235 insertions, 52 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 3b9da5b34a..2b3f0bef3c 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.api.errors import SynapseError
 from synapse.util.logcontext import PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
@@ -49,9 +50,6 @@ class Clock(object):
         l.start(msec / 1000.0, now=False)
         return l
 
-    def stop_looping_call(self, loop):
-        loop.stop()
-
     def call_later(self, delay, callback, *args, **kwargs):
         """Call something later
 
@@ -83,7 +81,7 @@ class Clock(object):
 
         def timed_out_fn():
             try:
-                ret_deferred.errback(RuntimeError("Timed out"))
+                ret_deferred.errback(SynapseError(504, "Timed out"))
             except:
                 pass
 
diff --git a/synapse/util/async.py b/synapse/util/async.py
index cd4d90f3cf..0d6f48e2d8 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,9 +16,13 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import PreserveLoggingContext, preserve_fn
+from .logcontext import (
+    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+)
 from synapse.util import unwrapFirstError
 
+from contextlib import contextmanager
+
 
 @defer.inlineCallbacks
 def sleep(seconds):
@@ -137,3 +141,47 @@ def concurrently_execute(func, args, limit):
         preserve_fn(_concurrently_execute_inner)()
         for _ in xrange(limit)
     ], consumeErrors=True).addErrback(unwrapFirstError)
+
+
+class Linearizer(object):
+    """Linearizes access to resources based on a key. Useful to ensure only one
+    thing is happening at a time on a given resource.
+
+    Example:
+
+        with (yield linearizer.queue("test_key")):
+            # do some work.
+
+    """
+    def __init__(self):
+        self.key_to_defer = {}
+
+    @defer.inlineCallbacks
+    def queue(self, key):
+        # If there is already a deferred in the queue, we pull it out so that
+        # we can wait on it later.
+        # Then we replace it with a deferred that we resolve *after* the
+        # context manager has exited.
+        # We only return the context manager after the previous deferred has
+        # resolved.
+        # This all has the net effect of creating a chain of deferreds that
+        # wait for the previous deferred before starting their work.
+        current_defer = self.key_to_defer.get(key)
+
+        new_defer = defer.Deferred()
+        self.key_to_defer[key] = new_defer
+
+        if current_defer:
+            yield preserve_context_over_deferred(current_defer)
+
+        @contextmanager
+        def _ctx_manager():
+            try:
+                yield
+            finally:
+                new_defer.callback(None)
+                current_d = self.key_to_defer.get(key)
+                if current_d is new_defer:
+                    self.key_to_defer.pop(key, None)
+
+        defer.returnValue(_ctx_manager())
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 35544b19fd..758f5982b0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -167,7 +167,8 @@ class CacheDescriptor(object):
                 % (orig.__name__,)
             )
 
-        self.cache = Cache(
+    def __get__(self, obj, objtype=None):
+        cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
@@ -175,14 +176,12 @@ class CacheDescriptor(object):
             tree=self.tree,
         )
 
-    def __get__(self, obj, objtype=None):
-
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result_d = self.cache.get(cache_key)
+                cached_result_d = cache.get(cache_key)
 
                 observer = cached_result_d.observe()
                 if DEBUG_CACHES:
@@ -204,7 +203,7 @@ class CacheDescriptor(object):
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = self.cache.sequence
+                sequence = cache.sequence
 
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
@@ -213,20 +212,21 @@ class CacheDescriptor(object):
                 )
 
                 def onErr(f):
-                    self.cache.invalidate(cache_key)
+                    cache.invalidate(cache_key)
                     return f
 
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                self.cache.update(sequence, cache_key, ret)
+                cache.update(sequence, cache_key, ret)
 
                 return preserve_context_over_deferred(ret.observe())
 
-        wrapped.invalidate = self.cache.invalidate
-        wrapped.invalidate_all = self.cache.invalidate_all
-        wrapped.invalidate_many = self.cache.invalidate_many
-        wrapped.prefill = self.cache.prefill
+        wrapped.invalidate = cache.invalidate
+        wrapped.invalidate_all = cache.invalidate_all
+        wrapped.invalidate_many = cache.invalidate_many
+        wrapped.prefill = cache.prefill
+        wrapped.cache = cache
 
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -240,11 +240,12 @@ class CacheListDescriptor(object):
     the list of missing keys to the wrapped fucntion.
     """
 
-    def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+                 inlineCallbacks=False):
         """
         Args:
             orig (function)
-            cache (Cache)
+            method_name (str); The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
             num_args (int)
             inlineCallbacks (bool): Whether orig is a generator that should
@@ -263,7 +264,7 @@ class CacheListDescriptor(object):
         self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
 
-        self.cache = cache
+        self.cached_method_name = cached_method_name
 
         self.sentinel = object()
 
@@ -277,11 +278,13 @@ class CacheListDescriptor(object):
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
-                % (self.list_name, cache.name,)
+                % (self.list_name, cached_method_name,)
             )
 
     def __get__(self, obj, objtype=None):
 
+        cache = getattr(obj, self.cached_method_name).cache
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@@ -297,14 +300,14 @@ class CacheListDescriptor(object):
                 key[self.list_pos] = arg
 
                 try:
-                    res = self.cache.get(tuple(key)).observe()
+                    res = cache.get(tuple(key)).observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
                     cached[arg] = res
                 except KeyError:
                     missing.append(arg)
 
             if missing:
-                sequence = self.cache.sequence
+                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
 
@@ -327,10 +330,10 @@ class CacheListDescriptor(object):
 
                     key = list(keyargs)
                     key[self.list_pos] = arg
-                    self.cache.update(sequence, tuple(key), observer)
+                    cache.update(sequence, tuple(key), observer)
 
                     def invalidate(f, key):
-                        self.cache.invalidate(key)
+                        cache.invalidate(key)
                         return f
                     observer.addErrback(invalidate, tuple(key))
 
@@ -370,7 +373,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
     )
 
 
-def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -400,7 +403,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
     """
     return lambda orig: CacheListDescriptor(
         orig,
-        cache=cache,
+        cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
         inlineCallbacks=inlineCallbacks,
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index be310ba320..36686b479e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -35,7 +35,7 @@ class ResponseCache(object):
             return None
 
     def set(self, key, deferred):
-        result = ObservableDeferred(deferred)
+        result = ObservableDeferred(deferred, consumeErrors=True)
         self.pending_result_cache[key] = result
 
         def remove(r):
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 8875813de4..d7cccc06b1 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,7 +15,9 @@
 
 from twisted.internet import defer
 
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_fn
+)
 
 from synapse.util import unwrapFirstError
 
@@ -25,6 +27,24 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+def registered_user(distributor, user):
+    return distributor.fire("registered_user", user)
+
+
+def user_left_room(distributor, user, room_id):
+    return preserve_context_over_fn(
+        distributor.fire,
+        "user_left_room", user=user, room_id=room_id
+    )
+
+
+def user_joined_room(distributor, user, room_id):
+    return preserve_context_over_fn(
+        distributor.fire,
+        "user_joined_room", user=user, room_id=room_id
+    )
+
+
 class Distributor(object):
     """A central dispatch point for loosely-connected pieces of code to
     register, observe, and fire signals.
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c51b641125..e1f374807e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -50,7 +50,7 @@ block_db_txn_duration = metrics.register_distribution(
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration"
+        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
     ]
 
     def __init__(self, clock, name):
@@ -58,14 +58,20 @@ class Measure(object):
         self.name = name
         self.start_context = None
         self.start = None
+        self.created_context = False
 
     def __enter__(self):
         self.start = self.clock.time_msec()
         self.start_context = LoggingContext.current_context()
-        if self.start_context:
-            self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
-            self.db_txn_count = self.start_context.db_txn_count
-            self.db_txn_duration = self.start_context.db_txn_duration
+        if not self.start_context:
+            logger.warn("Entered Measure without log context: %s", self.name)
+            self.start_context = LoggingContext("Measure")
+            self.start_context.__enter__()
+            self.created_context = True
+
+        self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+        self.db_txn_count = self.start_context.db_txn_count
+        self.db_txn_duration = self.start_context.db_txn_duration
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is not None or not self.start_context:
@@ -91,7 +97,12 @@ class Measure(object):
 
         block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
         block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
-        block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+        block_db_txn_count.inc_by(
+            context.db_txn_count - self.db_txn_count, self.name
+        )
         block_db_txn_duration.inc_by(
             context.db_txn_duration - self.db_txn_duration, self.name
         )
+
+        if self.created_context:
+            self.start_context.__exit__(exc_type, exc_val, exc_tb)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 4076eed269..1101881a2d 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -100,20 +100,6 @@ class _PerHostRatelimiter(object):
         self.current_processing = set()
         self.request_times = []
 
-    def is_empty(self):
-        time_now = self.clock.time_msec()
-        self.request_times[:] = [
-            r for r in self.request_times
-            if time_now - r < self.window_size
-        ]
-
-        return not (
-            self.ready_request_queue
-            or self.sleeping_requests
-            or self.current_processing
-            or self.request_times
-        )
-
     @contextlib.contextmanager
     def ratelimit(self):
         # `contextlib.contextmanager` takes a generator and turns it into a
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
new file mode 100644
index 0000000000..f4a9abf83f
--- /dev/null
+++ b/synapse/util/rlimit.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import resource
+import logging
+
+
+logger = logging.getLogger("synapse.app.homeserver")
+
+
+def change_resource_limit(soft_file_no):
+    try:
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+
+        if not soft_file_no:
+            soft_file_no = hard
+
+        resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
+        logger.info("Set file limit to: %d", soft_file_no)
+
+        resource.setrlimit(
+            resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
+        )
+    except (ValueError, resource.error) as e:
+        logger.warn("Failed to set file or core limit: %s", e)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index b490bb8725..a100f151d4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -21,10 +21,6 @@ _string_with_symbols = (
 )
 
 
-def origin_from_ucid(ucid):
-    return ucid.split("@", 1)[1]
-
-
 def random_string(length):
     return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
 
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
new file mode 100644
index 0000000000..a4f156cb3b
--- /dev/null
+++ b/synapse/util/versionstring.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def get_version_string(name, module):
+    try:
+        null = open(os.devnull, 'w')
+        cwd = os.path.dirname(os.path.abspath(module.__file__))
+        try:
+            git_branch = subprocess.check_output(
+                ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
+                stderr=null,
+                cwd=cwd,
+            ).strip()
+            git_branch = "b=" + git_branch
+        except subprocess.CalledProcessError:
+            git_branch = ""
+
+        try:
+            git_tag = subprocess.check_output(
+                ['git', 'describe', '--exact-match'],
+                stderr=null,
+                cwd=cwd,
+            ).strip()
+            git_tag = "t=" + git_tag
+        except subprocess.CalledProcessError:
+            git_tag = ""
+
+        try:
+            git_commit = subprocess.check_output(
+                ['git', 'rev-parse', '--short', 'HEAD'],
+                stderr=null,
+                cwd=cwd,
+            ).strip()
+        except subprocess.CalledProcessError:
+            git_commit = ""
+
+        try:
+            dirty_string = "-this_is_a_dirty_checkout"
+            is_dirty = subprocess.check_output(
+                ['git', 'describe', '--dirty=' + dirty_string],
+                stderr=null,
+                cwd=cwd,
+            ).strip().endswith(dirty_string)
+
+            git_dirty = "dirty" if is_dirty else ""
+        except subprocess.CalledProcessError:
+            git_dirty = ""
+
+        if git_branch or git_tag or git_commit or git_dirty:
+            git_version = ",".join(
+                s for s in
+                (git_branch, git_tag, git_commit, git_dirty,)
+                if s
+            )
+
+            return (
+                "%s/%s (%s)" % (
+                    name, module.__version__, git_version,
+                )
+            ).encode("ascii")
+    except Exception as e:
+        logger.info("Failed to check for git repository: %s", e)
+
+    return ("%s/%s" % (name, module.__version__,)).encode("ascii")