summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py2
-rw-r--r--synapse/util/expiringcache.py115
-rw-r--r--synapse/util/retryutils.py153
3 files changed, 268 insertions, 2 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index e77eba90ad..79109d0b19 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -99,8 +99,6 @@ class Clock(object):
             except:
                 pass
 
-            return res
-
         given_deferred.addCallbacks(callback=sucess, errback=err)
 
         timer = self.call_later(time_out, timed_out_fn)
diff --git a/synapse/util/expiringcache.py b/synapse/util/expiringcache.py
new file mode 100644
index 0000000000..1c7859297a
--- /dev/null
+++ b/synapse/util/expiringcache.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExpiringCache(object):
+    def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
+                 reset_expiry_on_get=False):
+        """
+        Args:
+            cache_name (str): Name of this cache, used for logging.
+            clock (Clock)
+            max_len (int): Max size of dict. If the dict grows larger than this
+                then the oldest items get automatically evicted. Default is 0,
+                which indicates there is no max limit.
+            expiry_ms (int): How long before an item is evicted from the cache
+                in milliseconds. Default is 0, indicating items never get
+                evicted based on time.
+            reset_expiry_on_get (bool): If true, will reset the expiry time for
+                an item on access. Defaults to False.
+
+        """
+        self._cache_name = cache_name
+
+        self._clock = clock
+
+        self._max_len = max_len
+        self._expiry_ms = expiry_ms
+
+        self._reset_expiry_on_get = reset_expiry_on_get
+
+        self._cache = {}
+
+    def start(self):
+        if not self._expiry_ms:
+            # Don't bother starting the loop if things never expire
+            return
+
+        def f():
+            self._prune_cache()
+
+        self._clock.looping_call(f, self._expiry_ms/2)
+
+    def __setitem__(self, key, value):
+        now = self._clock.time_msec()
+        self._cache[key] = _CacheEntry(now, value)
+
+        # Evict if there are now too many items
+        if self._max_len and len(self._cache.keys()) > self._max_len:
+            sorted_entries = sorted(
+                self._cache.items(),
+                key=lambda k, v: v.time,
+            )
+
+            for k, _ in sorted_entries[self._max_len:]:
+                self._cache.pop(k)
+
+    def __getitem__(self, key):
+        entry = self._cache[key]
+
+        if self._reset_expiry_on_get:
+            entry.time = self._clock.time_msec()
+
+        return entry.value
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def _prune_cache(self):
+        if not self._expiry_ms:
+            # zero expiry time means don't expire. This should never get called
+            # since we have this check in start too.
+            return
+        begin_length = len(self._cache)
+
+        now = self._clock.time_msec()
+
+        keys_to_delete = set()
+
+        for key, cache_entry in self._cache.items():
+            if now - cache_entry.time > self._expiry_ms:
+                keys_to_delete.add(key)
+
+        for k in keys_to_delete:
+            self._cache.pop(k)
+
+        logger.debug(
+            "[%s] _prune_cache before: %d, after len: %d",
+            self._cache_name, begin_length, len(self._cache.keys())
+        )
+
+
+class _CacheEntry(object):
+    def __init__(self, time, value):
+        self.time = time
+        self.value = value
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
new file mode 100644
index 0000000000..4e82232796
--- /dev/null
+++ b/synapse/util/retryutils.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import CodeMessageException
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class NotRetryingDestination(Exception):
+    def __init__(self, retry_last_ts, retry_interval, destination):
+        msg = "Not retrying server %s." % (destination,)
+        super(NotRetryingDestination, self).__init__(msg)
+
+        self.retry_last_ts = retry_last_ts
+        self.retry_interval = retry_interval
+        self.destination = destination
+
+
+@defer.inlineCallbacks
+def get_retry_limiter(destination, clock, store, **kwargs):
+    """For a given destination check if we have previously failed to
+    send a request there and are waiting before retrying the destination.
+    If we are not ready to retry the destination, this will raise a
+    NotRetryingDestination exception. Otherwise, will return a Context Manager
+    that will mark the destination as down if an exception is thrown (excluding
+    CodeMessageException with code < 500)
+
+    Example usage:
+
+        try:
+            limiter = yield get_retry_limiter(destination, clock, store)
+            with limiter:
+                response = yield do_request()
+        except NotRetryingDestination:
+            # We aren't ready to retry that destination.
+            raise
+    """
+    retry_last_ts, retry_interval = (0, 0)
+
+    retry_timings = yield store.get_destination_retry_timings(
+        destination
+    )
+
+    if retry_timings:
+        retry_last_ts, retry_interval = (
+            retry_timings.retry_last_ts, retry_timings.retry_interval
+        )
+
+        now = int(clock.time_msec())
+
+        if retry_last_ts + retry_interval > now:
+            raise NotRetryingDestination(
+                retry_last_ts=retry_last_ts,
+                retry_interval=retry_interval,
+                destination=destination,
+            )
+
+    defer.returnValue(
+        RetryDestinationLimiter(
+            destination,
+            clock,
+            store,
+            retry_interval,
+            **kwargs
+        )
+    )
+
+
+class RetryDestinationLimiter(object):
+    def __init__(self, destination, clock, store, retry_interval,
+                 min_retry_interval=5000, max_retry_interval=60 * 60 * 1000,
+                 multiplier_retry_interval=2,):
+        """Marks the destination as "down" if an exception is thrown in the
+        context, except for CodeMessageException with code < 500.
+
+        If no exception is raised, marks the destination as "up".
+
+        Args:
+            destination (str)
+            clock (Clock)
+            store (DataStore)
+            retry_interval (int): The next retry interval taken from the
+                database in milliseconds, or zero if the last request was
+                successful.
+            min_retry_interval (int): The minimum retry interval to use after
+                a failed request, in milliseconds.
+            max_retry_interval (int): The maximum retry interval to use after
+                a failed request, in milliseconds.
+            multiplier_retry_interval (int): The multiplier to use to increase
+                the retry interval after a failed request.
+        """
+        self.clock = clock
+        self.store = store
+        self.destination = destination
+
+        self.retry_interval = retry_interval
+        self.min_retry_interval = min_retry_interval
+        self.max_retry_interval = max_retry_interval
+        self.multiplier_retry_interval = multiplier_retry_interval
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        def err(failure):
+            logger.exception(
+                "Failed to store set_destination_retry_timings",
+                failure.value
+            )
+
+        valid_err_code = False
+        if exc_type is CodeMessageException:
+            valid_err_code = 0 <= exc_val.code < 500
+
+        if exc_type is None or valid_err_code:
+            # We connected successfully.
+            if not self.retry_interval:
+                return
+
+            retry_last_ts = 0
+            self.retry_interval = 0
+        else:
+            # We couldn't connect.
+            if self.retry_interval:
+                self.retry_interval *= self.multiplier_retry_interval
+
+                if self.retry_interval >= self.max_retry_interval:
+                    self.retry_interval = self.max_retry_interval
+            else:
+                self.retry_interval = self.min_retry_interval
+
+            retry_last_ts = int(self.clock.time_msec())
+
+        self.store.set_destination_retry_timings(
+            self.destination, retry_last_ts, self.retry_interval
+        ).addErrback(err)