summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/descriptors.py4
-rw-r--r--synapse/util/caches/expiringcache.py6
-rw-r--r--synapse/util/caches/lrucache.py28
-rw-r--r--synapse/util/file_consumer.py139
-rw-r--r--synapse/util/logcontext.py48
-rw-r--r--synapse/util/metrics.py75
-rw-r--r--synapse/util/retryutils.py12
-rw-r--r--synapse/util/threepids.py48
8 files changed, 333 insertions, 27 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index af65bfe7b8..bf3a66eae4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -75,6 +75,7 @@ class Cache(object):
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
+            evicted_callback=self._on_evicted,
         )
 
         self.name = name
@@ -83,6 +84,9 @@ class Cache(object):
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
+    def _on_evicted(self, evicted_count):
+        self.metrics.inc_evictions(evicted_count)
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6ad53a6390..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
         while self._max_len and len(self) > self._max_len:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self._size_estimate -= len(value.value)
+                removed_len = len(value.value)
+                self.metrics.inc_evictions(removed_len)
+                self._size_estimate -= removed_len
+            else:
+                self.metrics.inc_evictions()
 
     def __getitem__(self, key):
         try:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..f088dd430e 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+                 evicted_callback=None):
+        """
+        Args:
+            max_size (int):
+
+            keylen (int):
+
+            cache_type (type):
+                type of underlying cache to be used. Typically one of dict
+                or TreeCache.
+
+            size_callback (func(V) -> int | None):
+
+            evicted_callback (func(int)|None):
+                if not None, called on eviction with the size of the evicted
+                entry
+        """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
         def evict():
             while cache_len() > max_size:
                 todelete = list_root.prev_node
-                delete_node(todelete)
+                evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
+                if evicted_callback:
+                    evicted_callback(evicted_len)
 
         def synchronized(f):
             @wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            deleted_len = 1
             if size_callback:
-                cached_cache_len[0] -= size_callback(node.value)
+                deleted_len = size_callback(node.value)
+                cached_cache_len[0] -= deleted_len
 
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
+            return deleted_len
 
         @synchronized
         def cache_get(key, default=None, callbacks=[]):
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..90a2608d6f
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+
+import Queue
+
+
+class BackgroundFileConsumer(object):
+    """A consumer that writes to a file like object. Supports both push
+    and pull producers
+
+    Args:
+        file_obj (file): The file like object to write to. Closed when
+            finished.
+    """
+
+    # For PushProducers pause if we have this many unwritten slices
+    _PAUSE_ON_QUEUE_SIZE = 5
+    # And resume once the size of the queue is less than this
+    _RESUME_ON_QUEUE_SIZE = 2
+
+    def __init__(self, file_obj):
+        self._file_obj = file_obj
+
+        # Producer we're registered with
+        self._producer = None
+
+        # True if PushProducer, false if PullProducer
+        self.streaming = False
+
+        # For PushProducers, indicates whether we've paused the producer and
+        # need to call resumeProducing before we get more data.
+        self._paused_producer = False
+
+        # Queue of slices of bytes to be written. When producer calls
+        # unregister a final None is sent.
+        self._bytes_queue = Queue.Queue()
+
+        # Deferred that is resolved when finished writing
+        self._finished_deferred = None
+
+        # If the _writer thread throws an exception it gets stored here.
+        self._write_exception = None
+
+    def registerProducer(self, producer, streaming):
+        """Part of IConsumer interface
+
+        Args:
+            producer (IProducer)
+            streaming (bool): True if push based producer, False if pull
+                based.
+        """
+        if self._producer:
+            raise Exception("registerProducer called twice")
+
+        self._producer = producer
+        self.streaming = streaming
+        self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+        if not streaming:
+            self._producer.resumeProducing()
+
+    def unregisterProducer(self):
+        """Part of IProducer interface
+        """
+        self._producer = None
+        if not self._finished_deferred.called:
+            self._bytes_queue.put_nowait(None)
+
+    def write(self, bytes):
+        """Part of IProducer interface
+        """
+        if self._write_exception:
+            raise self._write_exception
+
+        if self._finished_deferred.called:
+            raise Exception("consumer has closed")
+
+        self._bytes_queue.put_nowait(bytes)
+
+        # If this is a PushProducer and the queue is getting behind
+        # then we pause the producer.
+        if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+            self._paused_producer = True
+            self._producer.pauseProducing()
+
+    def _writer(self):
+        """This is run in a background thread to write to the file.
+        """
+        try:
+            while self._producer or not self._bytes_queue.empty():
+                # If we've paused the producer check if we should resume the
+                # producer.
+                if self._producer and self._paused_producer:
+                    if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+                        reactor.callFromThread(self._resume_paused_producer)
+
+                bytes = self._bytes_queue.get()
+
+                # If we get a None (or empty list) then that's a signal used
+                # to indicate we should check if we should stop.
+                if bytes:
+                    self._file_obj.write(bytes)
+
+                # If its a pull producer then we need to explicitly ask for
+                # more stuff.
+                if not self.streaming and self._producer:
+                    reactor.callFromThread(self._producer.resumeProducing)
+        except Exception as e:
+            self._write_exception = e
+            raise
+        finally:
+            self._file_obj.close()
+
+    def wait(self):
+        """Returns a deferred that resolves when finished writing to file
+        """
+        return make_deferred_yieldable(self._finished_deferred)
+
+    def _resume_paused_producer(self):
+        """Gets called if we should resume producing after being paused
+        """
+        if self._paused_producer and self._producer:
+            self._paused_producer = False
+            self._producer.resumeProducing()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 48c9f6802d..94fa7cac98 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -52,13 +52,17 @@ except Exception:
 class LoggingContext(object):
     """Additional context for log formatting. Contexts are scoped within a
     "with" block.
+
     Args:
         name (str): Name for the context for debugging.
     """
 
     __slots__ = [
-        "previous_context", "name", "usage_start", "usage_end", "main_thread",
-        "__dict__", "tag", "alive",
+        "previous_context", "name", "ru_stime", "ru_utime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "usage_start", "usage_end",
+        "main_thread", "alive",
+        "request", "tag",
     ]
 
     thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
         def add_database_transaction(self, duration_ms):
             pass
 
+        def add_database_scheduled(self, sched_ms):
+            pass
+
         def __nonzero__(self):
             return False
 
@@ -94,9 +101,17 @@ class LoggingContext(object):
         self.ru_stime = 0.
         self.ru_utime = 0.
         self.db_txn_count = 0
-        self.db_txn_duration = 0.
+
+        # ms spent waiting for db txns, excluding scheduling time
+        self.db_txn_duration_ms = 0
+
+        # ms spent waiting for db txns to be scheduled
+        self.db_sched_duration_ms = 0
+
         self.usage_start = None
+        self.usage_end = None
         self.main_thread = threading.current_thread()
+        self.request = None
         self.tag = ""
         self.alive = True
 
@@ -105,7 +120,11 @@ class LoggingContext(object):
 
     @classmethod
     def current_context(cls):
-        """Get the current logging context from thread local storage"""
+        """Get the current logging context from thread local storage
+
+        Returns:
+            LoggingContext: the current logging context
+        """
         return getattr(cls.thread_local, "current_context", cls.sentinel)
 
     @classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
         self.alive = False
 
     def copy_to(self, record):
-        """Copy fields from this context to the record"""
-        for key, value in self.__dict__.items():
-            setattr(record, key, value)
+        """Copy logging fields from this context to a log record or
+        another LoggingContext
+        """
 
-        record.ru_utime, record.ru_stime = self.get_resource_usage()
+        # 'request' is the only field we currently use in the logger, so that's
+        # all we need to copy
+        record.request = self.request
 
     def start(self):
         if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
 
     def add_database_transaction(self, duration_ms):
         self.db_txn_count += 1
-        self.db_txn_duration += duration_ms / 1000.
+        self.db_txn_duration_ms += duration_ms
+
+    def add_database_scheduled(self, sched_ms):
+        """Record a use of the database pool
+
+        Args:
+            sched_ms (int): number of milliseconds it took us to get a
+                connection
+        """
+        self.db_sched_duration_ms += sched_ms
 
 
 class LoggingContextFilter(logging.Filter):
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..e4b5687a4b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-block_timer = metrics.register_distribution(
-    "block_timer",
-    labels=["block_name"]
+# total number of times we have hit this block
+block_counter = metrics.register_counter(
+    "block_count",
+    labels=["block_name"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_block_timer:count",
+            "_block_ru_utime:count",
+            "_block_ru_stime:count",
+            "_block_db_txn_count:count",
+            "_block_db_txn_duration:count",
+        )
+    )
+)
+
+block_timer = metrics.register_counter(
+    "block_time_seconds",
+    labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_timer:total",
+    ),
 )
 
-block_ru_utime = metrics.register_distribution(
-    "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+    "block_ru_utime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_utime:total",
+    ),
 )
 
-block_ru_stime = metrics.register_distribution(
-    "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+    "block_ru_stime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_stime:total",
+    ),
 )
 
-block_db_txn_count = metrics.register_distribution(
-    "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+    "block_db_txn_count", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_count:total",
+    ),
 )
 
-block_db_txn_duration = metrics.register_distribution(
-    "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+    "block_db_txn_duration_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_duration:total",
+    ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+    "block_db_sched_duration_seconds", labels=["block_name"],
 )
 
 
@@ -64,7 +101,9 @@ def measure_func(name):
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+        "ru_stime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "created_context",
     ]
 
     def __init__(self, clock, name):
@@ -84,13 +123,16 @@ class Measure(object):
 
         self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
         self.db_txn_count = self.start_context.db_txn_count
-        self.db_txn_duration = self.start_context.db_txn_duration
+        self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+        self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if isinstance(exc_type, Exception) or not self.start_context:
             return
 
         duration = self.clock.time_msec() - self.start
+
+        block_counter.inc(self.name)
         block_timer.inc_by(duration, self.name)
 
         context = LoggingContext.current_context()
@@ -114,7 +156,12 @@ class Measure(object):
             context.db_txn_count - self.db_txn_count, self.name
         )
         block_db_txn_duration.inc_by(
-            context.db_txn_duration - self.db_txn_duration, self.name
+            (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+            self.name
+        )
+        block_db_sched_duration.inc_by(
+            (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+            self.name
         )
 
         if self.created_context:
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1adedbb361..47b0bb5eb3 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
 
 class NotRetryingDestination(Exception):
     def __init__(self, retry_last_ts, retry_interval, destination):
+        """Raised by the limiter (and federation client) to indicate that we are
+        are deliberately not attempting to contact a given server.
+
+        Args:
+            retry_last_ts (int): the unix ts in milliseconds of our last attempt
+                to contact the server.  0 indicates that the last attempt was
+                successful or that we've never actually attempted to connect.
+            retry_interval (int): the time in milliseconds to wait until the next
+                attempt.
+            destination (str): the domain in question
+        """
+
         msg = "Not retrying server %s." % (destination,)
         super(NotRetryingDestination, self).__init__(msg)
 
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+    """Checks whether a given format of 3PID is allowed to be used on this HS
+
+    Args:
+        hs (synapse.server.HomeServer): server
+        medium (str): 3pid medium - e.g. email, msisdn
+        address (str): address within that medium (e.g. "wotan@matrix.org")
+            msisdns need to first have been canonicalised
+    Returns:
+        bool: whether the 3PID medium/address is allowed to be added to this HS
+    """
+
+    if hs.config.allowed_local_3pids:
+        for constraint in hs.config.allowed_local_3pids:
+            logger.debug(
+                "Checking 3PID %s (%s) against %s (%s)",
+                address, medium, constraint['pattern'], constraint['medium'],
+            )
+            if (
+                medium == constraint['medium'] and
+                re.match(constraint['pattern'], address)
+            ):
+                return True
+    else:
+        return True
+
+    return False