summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/descriptors.py4
-rw-r--r--synapse/util/caches/dictionary_cache.py6
-rw-r--r--synapse/util/caches/expiringcache.py6
-rw-r--r--synapse/util/caches/lrucache.py43
-rw-r--r--synapse/util/file_consumer.py139
-rw-r--r--synapse/util/logcontext.py119
-rw-r--r--synapse/util/metrics.py75
-rw-r--r--synapse/util/retryutils.py12
-rw-r--r--synapse/util/threepids.py48
9 files changed, 388 insertions, 64 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index af65bfe7b8..bf3a66eae4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -75,6 +75,7 @@ class Cache(object):
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
+            evicted_callback=self._on_evicted,
         )
 
         self.name = name
@@ -83,6 +84,9 @@ class Cache(object):
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
+    def _on_evicted(self, evicted_count):
+        self.metrics.inc_evictions(evicted_count)
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d4105822b3..1709e8b429 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -132,9 +132,13 @@ class DictionaryCache(object):
                 self._update_or_insert(key, value, known_absent)
 
     def _update_or_insert(self, key, value, known_absent):
-        entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
+        # We pop and reinsert as we need to tell the cache the size may have
+        # changed
+
+        entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
         entry.value.update(value)
         entry.known_absent.update(known_absent)
+        self.cache[key] = entry
 
     def _insert(self, key, value, known_absent):
         self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6ad53a6390..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
         while self._max_len and len(self) > self._max_len:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self._size_estimate -= len(value.value)
+                removed_len = len(value.value)
+                self.metrics.inc_evictions(removed_len)
+                self._size_estimate -= removed_len
+            else:
+                self.metrics.inc_evictions()
 
     def __getitem__(self, key):
         try:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..1c5a982094 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+                 evicted_callback=None):
+        """
+        Args:
+            max_size (int):
+
+            keylen (int):
+
+            cache_type (type):
+                type of underlying cache to be used. Typically one of dict
+                or TreeCache.
+
+            size_callback (func(V) -> int | None):
+
+            evicted_callback (func(int)|None):
+                if not None, called on eviction with the size of the evicted
+                entry
+        """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
         def evict():
             while cache_len() > max_size:
                 todelete = list_root.prev_node
-                delete_node(todelete)
+                evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
+                if evicted_callback:
+                    evicted_callback(evicted_len)
 
         def synchronized(f):
             @wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            deleted_len = 1
             if size_callback:
-                cached_cache_len[0] -= size_callback(node.value)
+                deleted_len = size_callback(node.value)
+                cached_cache_len[0] -= deleted_len
 
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
+            return deleted_len
 
         @synchronized
         def cache_get(key, default=None, callbacks=[]):
@@ -132,14 +154,21 @@ class LruCache(object):
         def cache_set(key, value, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
-                if value != node.value:
+                # We sometimes store large objects, e.g. dicts, which cause
+                # the inequality check to take a long time. So let's only do
+                # the check if we have some callbacks to call.
+                if node.callbacks and value != node.value:
                     for cb in node.callbacks:
                         cb()
                     node.callbacks.clear()
 
-                    if size_callback:
-                        cached_cache_len[0] -= size_callback(node.value)
-                        cached_cache_len[0] += size_callback(value)
+                # We don't bother to protect this by value != node.value as
+                # generally size_callback will be cheap compared with equality
+                # checks. (For example, taking the size of two dicts is quicker
+                # than comparing them for equality.)
+                if size_callback:
+                    cached_cache_len[0] -= size_callback(node.value)
+                    cached_cache_len[0] += size_callback(value)
 
                 node.callbacks.update(callbacks)
 
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..90a2608d6f
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+
+import Queue
+
+
+class BackgroundFileConsumer(object):
+    """A consumer that writes to a file like object. Supports both push
+    and pull producers
+
+    Args:
+        file_obj (file): The file like object to write to. Closed when
+            finished.
+    """
+
+    # For PushProducers pause if we have this many unwritten slices
+    _PAUSE_ON_QUEUE_SIZE = 5
+    # And resume once the size of the queue is less than this
+    _RESUME_ON_QUEUE_SIZE = 2
+
+    def __init__(self, file_obj):
+        self._file_obj = file_obj
+
+        # Producer we're registered with
+        self._producer = None
+
+        # True if PushProducer, false if PullProducer
+        self.streaming = False
+
+        # For PushProducers, indicates whether we've paused the producer and
+        # need to call resumeProducing before we get more data.
+        self._paused_producer = False
+
+        # Queue of slices of bytes to be written. When producer calls
+        # unregister a final None is sent.
+        self._bytes_queue = Queue.Queue()
+
+        # Deferred that is resolved when finished writing
+        self._finished_deferred = None
+
+        # If the _writer thread throws an exception it gets stored here.
+        self._write_exception = None
+
+    def registerProducer(self, producer, streaming):
+        """Part of IConsumer interface
+
+        Args:
+            producer (IProducer)
+            streaming (bool): True if push based producer, False if pull
+                based.
+        """
+        if self._producer:
+            raise Exception("registerProducer called twice")
+
+        self._producer = producer
+        self.streaming = streaming
+        self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+        if not streaming:
+            self._producer.resumeProducing()
+
+    def unregisterProducer(self):
+        """Part of IProducer interface
+        """
+        self._producer = None
+        if not self._finished_deferred.called:
+            self._bytes_queue.put_nowait(None)
+
+    def write(self, bytes):
+        """Part of IProducer interface
+        """
+        if self._write_exception:
+            raise self._write_exception
+
+        if self._finished_deferred.called:
+            raise Exception("consumer has closed")
+
+        self._bytes_queue.put_nowait(bytes)
+
+        # If this is a PushProducer and the queue is getting behind
+        # then we pause the producer.
+        if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+            self._paused_producer = True
+            self._producer.pauseProducing()
+
+    def _writer(self):
+        """This is run in a background thread to write to the file.
+        """
+        try:
+            while self._producer or not self._bytes_queue.empty():
+                # If we've paused the producer check if we should resume the
+                # producer.
+                if self._producer and self._paused_producer:
+                    if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+                        reactor.callFromThread(self._resume_paused_producer)
+
+                bytes = self._bytes_queue.get()
+
+                # If we get a None (or empty list) then that's a signal used
+                # to indicate we should check if we should stop.
+                if bytes:
+                    self._file_obj.write(bytes)
+
+                # If its a pull producer then we need to explicitly ask for
+                # more stuff.
+                if not self.streaming and self._producer:
+                    reactor.callFromThread(self._producer.resumeProducing)
+        except Exception as e:
+            self._write_exception = e
+            raise
+        finally:
+            self._file_obj.close()
+
+    def wait(self):
+        """Returns a deferred that resolves when finished writing to file
+        """
+        return make_deferred_yieldable(self._finished_deferred)
+
+    def _resume_paused_producer(self):
+        """Gets called if we should resume producing after being paused
+        """
+        if self._paused_producer and self._producer:
+            self._paused_producer = False
+            self._producer.resumeProducing()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 48c9f6802d..d660ec785b 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -52,13 +52,17 @@ except Exception:
 class LoggingContext(object):
     """Additional context for log formatting. Contexts are scoped within a
     "with" block.
+
     Args:
         name (str): Name for the context for debugging.
     """
 
     __slots__ = [
-        "previous_context", "name", "usage_start", "usage_end", "main_thread",
-        "__dict__", "tag", "alive",
+        "previous_context", "name", "ru_stime", "ru_utime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "usage_start", "usage_end",
+        "main_thread", "alive",
+        "request", "tag",
     ]
 
     thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
         def add_database_transaction(self, duration_ms):
             pass
 
+        def add_database_scheduled(self, sched_ms):
+            pass
+
         def __nonzero__(self):
             return False
 
@@ -94,9 +101,17 @@ class LoggingContext(object):
         self.ru_stime = 0.
         self.ru_utime = 0.
         self.db_txn_count = 0
-        self.db_txn_duration = 0.
+
+        # ms spent waiting for db txns, excluding scheduling time
+        self.db_txn_duration_ms = 0
+
+        # ms spent waiting for db txns to be scheduled
+        self.db_sched_duration_ms = 0
+
         self.usage_start = None
+        self.usage_end = None
         self.main_thread = threading.current_thread()
+        self.request = None
         self.tag = ""
         self.alive = True
 
@@ -105,7 +120,11 @@ class LoggingContext(object):
 
     @classmethod
     def current_context(cls):
-        """Get the current logging context from thread local storage"""
+        """Get the current logging context from thread local storage
+
+        Returns:
+            LoggingContext: the current logging context
+        """
         return getattr(cls.thread_local, "current_context", cls.sentinel)
 
     @classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
         self.alive = False
 
     def copy_to(self, record):
-        """Copy fields from this context to the record"""
-        for key, value in self.__dict__.items():
-            setattr(record, key, value)
+        """Copy logging fields from this context to a log record or
+        another LoggingContext
+        """
 
-        record.ru_utime, record.ru_stime = self.get_resource_usage()
+        # 'request' is the only field we currently use in the logger, so that's
+        # all we need to copy
+        record.request = self.request
 
     def start(self):
         if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
 
     def add_database_transaction(self, duration_ms):
         self.db_txn_count += 1
-        self.db_txn_duration += duration_ms / 1000.
+        self.db_txn_duration_ms += duration_ms
+
+    def add_database_scheduled(self, sched_ms):
+        """Record a use of the database pool
+
+        Args:
+            sched_ms (int): number of milliseconds it took us to get a
+                connection
+        """
+        self.db_sched_duration_ms += sched_ms
 
 
 class LoggingContextFilter(logging.Filter):
@@ -262,43 +292,43 @@ class PreserveLoggingContext(object):
 
 
 def preserve_fn(f):
-    """Wraps a function, to ensure that the current context is restored after
+    """Function decorator which wraps the function with run_in_background"""
+    def g(*args, **kwargs):
+        return run_in_background(f, *args, **kwargs)
+    return g
+
+
+def run_in_background(f, *args, **kwargs):
+    """Calls a function, ensuring that the current context is restored after
     return from the function, and that the sentinel context is set once the
     deferred returned by the funtion completes.
 
     Useful for wrapping functions that return a deferred which you don't yield
     on.
     """
-    def reset_context(result):
-        LoggingContext.set_current_context(LoggingContext.sentinel)
-        return result
-
-    def g(*args, **kwargs):
-        current = LoggingContext.current_context()
-        res = f(*args, **kwargs)
-        if isinstance(res, defer.Deferred) and not res.called:
-            # The function will have reset the context before returning, so
-            # we need to restore it now.
-            LoggingContext.set_current_context(current)
-
-            # The original context will be restored when the deferred
-            # completes, but there is nothing waiting for it, so it will
-            # get leaked into the reactor or some other function which
-            # wasn't expecting it. We therefore need to reset the context
-            # here.
-            #
-            # (If this feels asymmetric, consider it this way: we are
-            # effectively forking a new thread of execution. We are
-            # probably currently within a ``with LoggingContext()`` block,
-            # which is supposed to have a single entry and exit point. But
-            # by spawning off another deferred, we are effectively
-            # adding a new exit point.)
-            res.addBoth(reset_context)
-        return res
-    return g
+    current = LoggingContext.current_context()
+    res = f(*args, **kwargs)
+    if isinstance(res, defer.Deferred) and not res.called:
+        # The function will have reset the context before returning, so
+        # we need to restore it now.
+        LoggingContext.set_current_context(current)
+
+        # The original context will be restored when the deferred
+        # completes, but there is nothing waiting for it, so it will
+        # get leaked into the reactor or some other function which
+        # wasn't expecting it. We therefore need to reset the context
+        # here.
+        #
+        # (If this feels asymmetric, consider it this way: we are
+        # effectively forking a new thread of execution. We are
+        # probably currently within a ``with LoggingContext()`` block,
+        # which is supposed to have a single entry and exit point. But
+        # by spawning off another deferred, we are effectively
+        # adding a new exit point.)
+        res.addBoth(_set_context_cb, LoggingContext.sentinel)
+    return res
 
 
-@defer.inlineCallbacks
 def make_deferred_yieldable(deferred):
     """Given a deferred, make it follow the Synapse logcontext rules:
 
@@ -312,9 +342,16 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to preserve_fn.)
     """
-    with PreserveLoggingContext():
-        r = yield deferred
-    defer.returnValue(r)
+    if isinstance(deferred, defer.Deferred) and not deferred.called:
+        prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+        deferred.addBoth(_set_context_cb, prev_context)
+    return deferred
+
+
+def _set_context_cb(result, context):
+    """A callback function which just sets the logging context"""
+    LoggingContext.set_current_context(context)
+    return result
 
 
 # modules to ignore in `logcontext_tracer`
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..e4b5687a4b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-block_timer = metrics.register_distribution(
-    "block_timer",
-    labels=["block_name"]
+# total number of times we have hit this block
+block_counter = metrics.register_counter(
+    "block_count",
+    labels=["block_name"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_block_timer:count",
+            "_block_ru_utime:count",
+            "_block_ru_stime:count",
+            "_block_db_txn_count:count",
+            "_block_db_txn_duration:count",
+        )
+    )
+)
+
+block_timer = metrics.register_counter(
+    "block_time_seconds",
+    labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_timer:total",
+    ),
 )
 
-block_ru_utime = metrics.register_distribution(
-    "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+    "block_ru_utime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_utime:total",
+    ),
 )
 
-block_ru_stime = metrics.register_distribution(
-    "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+    "block_ru_stime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_stime:total",
+    ),
 )
 
-block_db_txn_count = metrics.register_distribution(
-    "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+    "block_db_txn_count", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_count:total",
+    ),
 )
 
-block_db_txn_duration = metrics.register_distribution(
-    "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+    "block_db_txn_duration_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_duration:total",
+    ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+    "block_db_sched_duration_seconds", labels=["block_name"],
 )
 
 
@@ -64,7 +101,9 @@ def measure_func(name):
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+        "ru_stime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "created_context",
     ]
 
     def __init__(self, clock, name):
@@ -84,13 +123,16 @@ class Measure(object):
 
         self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
         self.db_txn_count = self.start_context.db_txn_count
-        self.db_txn_duration = self.start_context.db_txn_duration
+        self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+        self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if isinstance(exc_type, Exception) or not self.start_context:
             return
 
         duration = self.clock.time_msec() - self.start
+
+        block_counter.inc(self.name)
         block_timer.inc_by(duration, self.name)
 
         context = LoggingContext.current_context()
@@ -114,7 +156,12 @@ class Measure(object):
             context.db_txn_count - self.db_txn_count, self.name
         )
         block_db_txn_duration.inc_by(
-            context.db_txn_duration - self.db_txn_duration, self.name
+            (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+            self.name
+        )
+        block_db_sched_duration.inc_by(
+            (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+            self.name
         )
 
         if self.created_context:
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1adedbb361..47b0bb5eb3 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
 
 class NotRetryingDestination(Exception):
     def __init__(self, retry_last_ts, retry_interval, destination):
+        """Raised by the limiter (and federation client) to indicate that we are
+        are deliberately not attempting to contact a given server.
+
+        Args:
+            retry_last_ts (int): the unix ts in milliseconds of our last attempt
+                to contact the server.  0 indicates that the last attempt was
+                successful or that we've never actually attempted to connect.
+            retry_interval (int): the time in milliseconds to wait until the next
+                attempt.
+            destination (str): the domain in question
+        """
+
         msg = "Not retrying server %s." % (destination,)
         super(NotRetryingDestination, self).__init__(msg)
 
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+    """Checks whether a given format of 3PID is allowed to be used on this HS
+
+    Args:
+        hs (synapse.server.HomeServer): server
+        medium (str): 3pid medium - e.g. email, msisdn
+        address (str): address within that medium (e.g. "wotan@matrix.org")
+            msisdns need to first have been canonicalised
+    Returns:
+        bool: whether the 3PID medium/address is allowed to be added to this HS
+    """
+
+    if hs.config.allowed_local_3pids:
+        for constraint in hs.config.allowed_local_3pids:
+            logger.debug(
+                "Checking 3PID %s (%s) against %s (%s)",
+                address, medium, constraint['pattern'], constraint['medium'],
+            )
+            if (
+                medium == constraint['medium'] and
+                re.match(constraint['pattern'], address)
+            ):
+                return True
+    else:
+        return True
+
+    return False