summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--docs/log_contexts.rst11
-rwxr-xr-xsynapse/app/synctl.py47
-rw-r--r--synapse/storage/event_federation.py2
-rw-r--r--synapse/util/caches/descriptors.py30
-rw-r--r--synapse/util/logcontext.py23
-rw-r--r--tests/util/caches/test_descriptors.py91
6 files changed, 180 insertions, 24 deletions
diff --git a/docs/log_contexts.rst b/docs/log_contexts.rst
index 8d04a973de..eb1784e700 100644
--- a/docs/log_contexts.rst
+++ b/docs/log_contexts.rst
@@ -204,9 +204,14 @@ That doesn't follow the rules, but we can fix it by wrapping it with
 This technique works equally for external functions which return deferreds,
 or deferreds we have made ourselves.
 
-XXX: think this is what ``preserve_context_over_deferred`` is supposed to do,
-though it is broken, in that it only restores the logcontext for the duration
-of the callbacks, which doesn't comply with the logcontext rules.
+You can also use ``logcontext.make_deferred_yieldable``, which just does the
+boilerplate for you, so the above could be written:
+
+.. code:: python
+
+    def sleep(seconds):
+        return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds))
+
 
 Fire-and-forget
 ---------------
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index c045588866..23eb6a1ec4 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -23,14 +23,27 @@ import signal
 import subprocess
 import sys
 import yaml
+import errno
+import time
 
 SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
 
 GREEN = "\x1b[1;32m"
+YELLOW = "\x1b[1;33m"
 RED = "\x1b[1;31m"
 NORMAL = "\x1b[m"
 
 
+def pid_running(pid):
+    try:
+        os.kill(pid, 0)
+        return True
+    except OSError, err:
+        if err.errno == errno.EPERM:
+            return True
+        return False
+
+
 def write(message, colour=NORMAL, stream=sys.stdout):
     if colour == NORMAL:
         stream.write(message + "\n")
@@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
         stream.write(colour + message + NORMAL + "\n")
 
 
+def abort(message, colour=RED, stream=sys.stderr):
+    write(message, colour, stream)
+    sys.exit(1)
+
+
 def start(configfile):
     write("Starting ...")
     args = SYNAPSE
@@ -45,7 +63,8 @@ def start(configfile):
 
     try:
         subprocess.check_call(args)
-        write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+        write("started synapse.app.homeserver(%r)" %
+              (configfile,), colour=GREEN)
     except subprocess.CalledProcessError as e:
         write(
             "error starting (exit code: %d); see above for logs" % e.returncode,
@@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
 def stop(pidfile, app):
     if os.path.exists(pidfile):
         pid = int(open(pidfile).read())
-        os.kill(pid, signal.SIGTERM)
-        write("stopped %s" % (app,), colour=GREEN)
+        try:
+            os.kill(pid, signal.SIGTERM)
+            write("stopped %s" % (app,), colour=GREEN)
+        except OSError, err:
+            if err.errno == errno.ESRCH:
+                write("%s not running" % (app,), colour=YELLOW)
+            elif err.errno == errno.EPERM:
+                abort("Cannot stop %s: Operation not permitted" % (app,))
+            else:
+                abort("Cannot stop %s: Unknown error" % (app,))
 
 
 Worker = collections.namedtuple("Worker", [
@@ -190,7 +217,19 @@ def main():
         if start_stop_synapse:
             stop(pidfile, "synapse.app.homeserver")
 
-        # TODO: Wait for synapse to actually shutdown before starting it again
+    # Wait for synapse to actually shutdown before starting it again
+    if action == "restart":
+        running_pids = []
+        if start_stop_synapse and os.path.exists(pidfile):
+            running_pids.append(int(open(pidfile).read()))
+        for worker in workers:
+            if os.path.exists(worker.pidfile):
+                running_pids.append(int(open(worker.pidfile).read()))
+        if len(running_pids) > 0:
+            write("Waiting for process to exit before restarting...")
+            for running_pid in running_pids:
+                while pid_running(running_pid):
+                    time.sleep(0.2)
 
     if action == "start" or action == "restart":
         if start_stop_synapse:
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 43b5b49986..519059c306 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
         txn.execute(sql, (room_id, ))
 
         results = []
-        for event_id, depth in txn:
+        for event_id, depth in txn.fetchall():
             hashes = self._get_event_reference_hashes_txn(txn, event_id)
             prev_hashes = {
                 k: encode_base64(v) for k, v in hashes.items()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 19595df422..5c30ed235d 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -15,12 +15,9 @@
 import logging
 
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
 
 from . import DEBUG_CACHES, register_cache
 
@@ -328,11 +325,9 @@ class CacheDescriptor(_CacheDescriptorBase):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return preserve_context_over_deferred(observer)
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
 
@@ -342,10 +337,11 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()
 
-                return preserve_context_over_deferred(ret.observe())
+            return logcontext.make_deferred_yieldable(observer)
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
@@ -362,7 +358,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """
 
     def __init__(self, orig, cached_method_name, list_name, num_args=None,
@@ -433,8 +433,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
 
@@ -443,8 +442,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -471,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     results.update(res)
                     return results
 
-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index ff67b1d794..857afee7cb 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
 def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
+
+    Deprecated: this almost certainly doesn't do want you want, ie make
+    the deferred follow the synapse logcontext rules: try
+    ``make_deferred_yieldable`` instead.
     """
     if context is None:
         context = LoggingContext.current_context()
@@ -359,6 +363,25 @@ def preserve_fn(f):
     return g
 
 
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed (or is not actually a Deferred), essentially
+    does nothing (just returns another completed deferred with the
+    result/failure).
+
+    If the deferred has not yet completed, resets the logcontext before
+    returning a deferred. Then, when the deferred completes, restores the
+    current logcontext before running callbacks/errbacks.
+
+    (This is more-or-less the opposite operation to preserve_fn.)
+    """
+    with PreserveLoggingContext():
+        r = yield deferred
+    defer.returnValue(r)
+
+
 # modules to ignore in `logcontext_tracer`
 _to_ignore = [
     "synapse.util.logcontext",
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 419281054d..4414e86771 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -12,11 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
 import mock
+from synapse.api.errors import SynapseError
+from synapse.util import async
+from synapse.util import logcontext
 from twisted.internet import defer
 from synapse.util.caches import descriptors
 from tests import unittest
 
+logger = logging.getLogger(__name__)
+
 
 class DescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
@@ -84,3 +91,87 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 5)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+    def test_cache_logcontexts(self):
+        """Check that logcontexts are set and restored correctly when
+        using the cache."""
+
+        complete_lookup = defer.Deferred()
+
+        class Cls(object):
+            @descriptors.cached()
+            def fn(self, arg1):
+                @defer.inlineCallbacks
+                def inner_fn():
+                    with logcontext.PreserveLoggingContext():
+                        yield complete_lookup
+                    defer.returnValue(1)
+
+                return inner_fn()
+
+        @defer.inlineCallbacks
+        def do_lookup():
+            with logcontext.LoggingContext() as c1:
+                c1.name = "c1"
+                r = yield obj.fn(1)
+                self.assertEqual(logcontext.LoggingContext.current_context(),
+                                 c1)
+            defer.returnValue(r)
+
+        def check_result(r):
+            self.assertEqual(r, 1)
+
+        obj = Cls()
+
+        # set off a deferred which will do a cache lookup
+        d1 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+        d1.addCallback(check_result)
+
+        # and another
+        d2 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+        d2.addCallback(check_result)
+
+        # let the lookup complete
+        complete_lookup.callback(None)
+
+        return defer.gatherResults([d1, d2])
+
+    def test_cache_logcontexts_with_exception(self):
+        """Check that the cache sets and restores logcontexts correctly when
+        the lookup function throws an exception"""
+
+        class Cls(object):
+            @descriptors.cached()
+            def fn(self, arg1):
+                @defer.inlineCallbacks
+                def inner_fn():
+                    yield async.run_on_reactor()
+                    raise SynapseError(400, "blah")
+
+                return inner_fn()
+
+        @defer.inlineCallbacks
+        def do_lookup():
+            with logcontext.LoggingContext() as c1:
+                c1.name = "c1"
+                try:
+                    yield obj.fn(1)
+                    self.fail("No exception thrown")
+                except SynapseError:
+                    pass
+
+                self.assertEqual(logcontext.LoggingContext.current_context(),
+                                 c1)
+
+        obj = Cls()
+
+        # set off a deferred which will do a cache lookup
+        d1 = do_lookup()
+        self.assertEqual(logcontext.LoggingContext.current_context(),
+                         logcontext.LoggingContext.sentinel)
+
+        return d1