diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/__init__.py | 61 | ||||
-rw-r--r-- | synapse/util/expiringcache.py | 115 | ||||
-rw-r--r-- | synapse/util/frozenutils.py | 8 | ||||
-rw-r--r-- | synapse/util/lrucache.py | 117 | ||||
-rw-r--r-- | synapse/util/ratelimitutils.py | 216 | ||||
-rw-r--r-- | synapse/util/retryutils.py | 153 |
6 files changed, 667 insertions, 3 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 4e837a918e..79109d0b19 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,9 +15,12 @@ from synapse.util.logcontext import LoggingContext -from twisted.internet import reactor +from twisted.internet import defer, reactor, task import time +import logging + +logger = logging.getLogger(__name__) class Clock(object): @@ -35,6 +38,14 @@ class Clock(object): """Returns the current system time in miliseconds since epoch.""" return self.time() * 1000 + def looping_call(self, f, msec): + l = task.LoopingCall(f) + l.start(msec/1000.0, now=False) + return l + + def stop_looping_call(self, loop): + loop.stop() + def call_later(self, delay, callback): current_context = LoggingContext.current_context() @@ -45,3 +56,51 @@ class Clock(object): def cancel_call_later(self, timer): timer.cancel() + + def time_bound_deferred(self, given_deferred, time_out): + if given_deferred.called: + return given_deferred + + ret_deferred = defer.Deferred() + + def timed_out_fn(): + try: + ret_deferred.errback(RuntimeError("Timed out")) + except: + pass + + try: + given_deferred.cancel() + except: + pass + + timer = None + + def cancel(res): + try: + self.cancel_call_later(timer) + except: + pass + return res + + ret_deferred.addBoth(cancel) + + def sucess(res): + try: + ret_deferred.callback(res) + except: + pass + + return res + + def err(res): + try: + ret_deferred.errback(res) + except: + pass + + given_deferred.addCallbacks(callback=sucess, errback=err) + + timer = self.call_later(time_out, timed_out_fn) + + return ret_deferred diff --git a/synapse/util/expiringcache.py b/synapse/util/expiringcache.py new file mode 100644 index 0000000000..1c7859297a --- /dev/null +++ b/synapse/util/expiringcache.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +logger = logging.getLogger(__name__) + + +class ExpiringCache(object): + def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, + reset_expiry_on_get=False): + """ + Args: + cache_name (str): Name of this cache, used for logging. + clock (Clock) + max_len (int): Max size of dict. If the dict grows larger than this + then the oldest items get automatically evicted. Default is 0, + which indicates there is no max limit. + expiry_ms (int): How long before an item is evicted from the cache + in milliseconds. Default is 0, indicating items never get + evicted based on time. + reset_expiry_on_get (bool): If true, will reset the expiry time for + an item on access. Defaults to False. + + """ + self._cache_name = cache_name + + self._clock = clock + + self._max_len = max_len + self._expiry_ms = expiry_ms + + self._reset_expiry_on_get = reset_expiry_on_get + + self._cache = {} + + def start(self): + if not self._expiry_ms: + # Don't bother starting the loop if things never expire + return + + def f(): + self._prune_cache() + + self._clock.looping_call(f, self._expiry_ms/2) + + def __setitem__(self, key, value): + now = self._clock.time_msec() + self._cache[key] = _CacheEntry(now, value) + + # Evict if there are now too many items + if self._max_len and len(self._cache.keys()) > self._max_len: + sorted_entries = sorted( + self._cache.items(), + key=lambda k, v: v.time, + ) + + for k, _ in sorted_entries[self._max_len:]: + self._cache.pop(k) + + def __getitem__(self, key): + entry = self._cache[key] + + if self._reset_expiry_on_get: + entry.time = self._clock.time_msec() + + return entry.value + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def _prune_cache(self): + if not self._expiry_ms: + # zero expiry time means don't expire. This should never get called + # since we have this check in start too. + return + begin_length = len(self._cache) + + now = self._clock.time_msec() + + keys_to_delete = set() + + for key, cache_entry in self._cache.items(): + if now - cache_entry.time > self._expiry_ms: + keys_to_delete.add(key) + + for k in keys_to_delete: + self._cache.pop(k) + + logger.debug( + "[%s] _prune_cache before: %d, after len: %d", + self._cache_name, begin_length, len(self._cache.keys()) + ) + + +class _CacheEntry(object): + def __init__(self, time, value): + self.time = time + self.value = value diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index a13a2015e4..9e10d37aec 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -21,6 +21,9 @@ def freeze(o): if t is dict: return frozendict({k: freeze(v) for k, v in o.items()}) + if t is frozendict: + return o + if t is str or t is unicode: return o @@ -33,10 +36,11 @@ def freeze(o): def unfreeze(o): - if isinstance(o, frozendict) or isinstance(o, dict): + t = type(o) + if t is dict or t is frozendict: return dict({k: unfreeze(v) for k, v in o.items()}) - if isinstance(o, basestring): + if t is str or t is unicode: return o try: diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py new file mode 100644 index 0000000000..f115f50e50 --- /dev/null +++ b/synapse/util/lrucache.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class LruCache(object): + """Least-recently-used cache.""" + # TODO(mjark) Add hit/miss counters + # TODO(mjark) Add mutex for linked list for thread safety. + def __init__(self, max_size): + cache = {} + list_root = [] + list_root[:] = [list_root, list_root, None, None] + + PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 + + def add_node(key, value): + prev_node = list_root + next_node = prev_node[NEXT] + node = [prev_node, next_node, key, value] + prev_node[NEXT] = node + next_node[PREV] = node + cache[key] = node + + def move_node_to_front(node): + prev_node = node[PREV] + next_node = node[NEXT] + prev_node[NEXT] = next_node + next_node[PREV] = prev_node + prev_node = list_root + next_node = prev_node[NEXT] + node[PREV] = prev_node + node[NEXT] = next_node + prev_node[NEXT] = node + next_node[PREV] = node + + def delete_node(node): + prev_node = node[PREV] + next_node = node[NEXT] + prev_node[NEXT] = next_node + next_node[PREV] = prev_node + cache.pop(node[KEY], None) + + def cache_get(key, default=None): + node = cache.get(key, None) + if node is not None: + move_node_to_front(node) + return node[VALUE] + else: + return default + + def cache_set(key, value): + node = cache.get(key, None) + if node is not None: + move_node_to_front(node) + node[VALUE] = value + else: + add_node(key, value) + if len(cache) > max_size: + delete_node(list_root[PREV]) + + def cache_set_default(key, value): + node = cache.get(key, None) + if node is not None: + return node[VALUE] + else: + add_node(key, value) + if len(cache) > max_size: + delete_node(list_root[PREV]) + return value + + def cache_pop(key, default=None): + node = cache.get(key, None) + if node: + delete_node(node) + return node[VALUE] + else: + return default + + def cache_len(): + return len(cache) + + self.sentinel = object() + self.get = cache_get + self.set = cache_set + self.setdefault = cache_set_default + self.pop = cache_pop + self.len = cache_len + + def __getitem__(self, key): + result = self.get(key, self.sentinel) + if result is self.sentinel: + raise KeyError() + else: + return result + + def __setitem__(self, key, value): + self.set(key, value) + + def __delitem__(self, key, value): + result = self.pop(key, self.sentinel) + if result is self.sentinel: + raise KeyError() + + def __len__(self): + return self.len() diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py new file mode 100644 index 0000000000..d4457af950 --- /dev/null +++ b/synapse/util/ratelimitutils.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import LimitExceededError + +from synapse.util.async import sleep + +import collections +import contextlib +import logging + + +logger = logging.getLogger(__name__) + + +class FederationRateLimiter(object): + def __init__(self, clock, window_size, sleep_limit, sleep_msec, + reject_limit, concurrent_requests): + """ + Args: + clock (Clock) + window_size (int): The window size in milliseconds. + sleep_limit (int): The number of requests received in the last + `window_size` milliseconds before we artificially start + delaying processing of requests. + sleep_msec (int): The number of milliseconds to delay processing + of incoming requests by. + reject_limit (int): The maximum number of requests that are can be + queued for processing before we start rejecting requests with + a 429 Too Many Requests response. + concurrent_requests (int): The number of concurrent requests to + process. + """ + self.clock = clock + + self.window_size = window_size + self.sleep_limit = sleep_limit + self.sleep_msec = sleep_msec + self.reject_limit = reject_limit + self.concurrent_requests = concurrent_requests + + self.ratelimiters = {} + + def ratelimit(self, host): + """Used to ratelimit an incoming request from given host + + Example usage: + + with rate_limiter.ratelimit(origin) as wait_deferred: + yield wait_deferred + # Handle request ... + + Args: + host (str): Origin of incoming request. + + Returns: + _PerHostRatelimiter + """ + return self.ratelimiters.setdefault( + host, + _PerHostRatelimiter( + clock=self.clock, + window_size=self.window_size, + sleep_limit=self.sleep_limit, + sleep_msec=self.sleep_msec, + reject_limit=self.reject_limit, + concurrent_requests=self.concurrent_requests, + ) + ).ratelimit() + + +class _PerHostRatelimiter(object): + def __init__(self, clock, window_size, sleep_limit, sleep_msec, + reject_limit, concurrent_requests): + self.clock = clock + + self.window_size = window_size + self.sleep_limit = sleep_limit + self.sleep_msec = sleep_msec + self.reject_limit = reject_limit + self.concurrent_requests = concurrent_requests + + self.sleeping_requests = set() + self.ready_request_queue = collections.OrderedDict() + self.current_processing = set() + self.request_times = [] + + def is_empty(self): + time_now = self.clock.time_msec() + self.request_times[:] = [ + r for r in self.request_times + if time_now - r < self.window_size + ] + + return not ( + self.ready_request_queue + or self.sleeping_requests + or self.current_processing + or self.request_times + ) + + @contextlib.contextmanager + def ratelimit(self): + # `contextlib.contextmanager` takes a generator and turns it into a + # context manager. The generator should only yield once with a value + # to be returned by manager. + # Exceptions will be reraised at the yield. + + request_id = object() + ret = self._on_enter(request_id) + try: + yield ret + finally: + self._on_exit(request_id) + + def _on_enter(self, request_id): + time_now = self.clock.time_msec() + self.request_times[:] = [ + r for r in self.request_times + if time_now - r < self.window_size + ] + + queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) + if queue_size > self.reject_limit: + raise LimitExceededError( + retry_after_ms=int( + self.window_size / self.sleep_limit + ), + ) + + self.request_times.append(time_now) + + def queue_request(): + if len(self.current_processing) > self.concurrent_requests: + logger.debug("Ratelimit [%s]: Queue req", id(request_id)) + queue_defer = defer.Deferred() + self.ready_request_queue[request_id] = queue_defer + return queue_defer + else: + return defer.succeed(None) + + logger.debug( + "Ratelimit [%s]: len(self.request_times)=%d", + id(request_id), len(self.request_times), + ) + + if len(self.request_times) > self.sleep_limit: + logger.debug( + "Ratelimit [%s]: sleeping req", + id(request_id), + ) + ret_defer = sleep(self.sleep_msec/1000.0) + + self.sleeping_requests.add(request_id) + + def on_wait_finished(_): + logger.debug( + "Ratelimit [%s]: Finished sleeping", + id(request_id), + ) + self.sleeping_requests.discard(request_id) + queue_defer = queue_request() + return queue_defer + + ret_defer.addBoth(on_wait_finished) + else: + ret_defer = queue_request() + + def on_start(r): + logger.debug( + "Ratelimit [%s]: Processing req", + id(request_id), + ) + self.current_processing.add(request_id) + return r + + def on_err(r): + self.current_processing.discard(request_id) + return r + + def on_both(r): + # Ensure that we've properly cleaned up. + self.sleeping_requests.discard(request_id) + self.ready_request_queue.pop(request_id, None) + return r + + ret_defer.addCallbacks(on_start, on_err) + ret_defer.addBoth(on_both) + return ret_defer + + def _on_exit(self, request_id): + logger.debug( + "Ratelimit [%s]: Processed req", + id(request_id), + ) + self.current_processing.discard(request_id) + try: + request_id, deferred = self.ready_request_queue.popitem() + self.current_processing.add(request_id) + deferred.callback(None) + except KeyError: + pass diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py new file mode 100644 index 0000000000..4e82232796 --- /dev/null +++ b/synapse/util/retryutils.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import CodeMessageException + +import logging + + +logger = logging.getLogger(__name__) + + +class NotRetryingDestination(Exception): + def __init__(self, retry_last_ts, retry_interval, destination): + msg = "Not retrying server %s." % (destination,) + super(NotRetryingDestination, self).__init__(msg) + + self.retry_last_ts = retry_last_ts + self.retry_interval = retry_interval + self.destination = destination + + +@defer.inlineCallbacks +def get_retry_limiter(destination, clock, store, **kwargs): + """For a given destination check if we have previously failed to + send a request there and are waiting before retrying the destination. + If we are not ready to retry the destination, this will raise a + NotRetryingDestination exception. Otherwise, will return a Context Manager + that will mark the destination as down if an exception is thrown (excluding + CodeMessageException with code < 500) + + Example usage: + + try: + limiter = yield get_retry_limiter(destination, clock, store) + with limiter: + response = yield do_request() + except NotRetryingDestination: + # We aren't ready to retry that destination. + raise + """ + retry_last_ts, retry_interval = (0, 0) + + retry_timings = yield store.get_destination_retry_timings( + destination + ) + + if retry_timings: + retry_last_ts, retry_interval = ( + retry_timings.retry_last_ts, retry_timings.retry_interval + ) + + now = int(clock.time_msec()) + + if retry_last_ts + retry_interval > now: + raise NotRetryingDestination( + retry_last_ts=retry_last_ts, + retry_interval=retry_interval, + destination=destination, + ) + + defer.returnValue( + RetryDestinationLimiter( + destination, + clock, + store, + retry_interval, + **kwargs + ) + ) + + +class RetryDestinationLimiter(object): + def __init__(self, destination, clock, store, retry_interval, + min_retry_interval=5000, max_retry_interval=60 * 60 * 1000, + multiplier_retry_interval=2,): + """Marks the destination as "down" if an exception is thrown in the + context, except for CodeMessageException with code < 500. + + If no exception is raised, marks the destination as "up". + + Args: + destination (str) + clock (Clock) + store (DataStore) + retry_interval (int): The next retry interval taken from the + database in milliseconds, or zero if the last request was + successful. + min_retry_interval (int): The minimum retry interval to use after + a failed request, in milliseconds. + max_retry_interval (int): The maximum retry interval to use after + a failed request, in milliseconds. + multiplier_retry_interval (int): The multiplier to use to increase + the retry interval after a failed request. + """ + self.clock = clock + self.store = store + self.destination = destination + + self.retry_interval = retry_interval + self.min_retry_interval = min_retry_interval + self.max_retry_interval = max_retry_interval + self.multiplier_retry_interval = multiplier_retry_interval + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + def err(failure): + logger.exception( + "Failed to store set_destination_retry_timings", + failure.value + ) + + valid_err_code = False + if exc_type is CodeMessageException: + valid_err_code = 0 <= exc_val.code < 500 + + if exc_type is None or valid_err_code: + # We connected successfully. + if not self.retry_interval: + return + + retry_last_ts = 0 + self.retry_interval = 0 + else: + # We couldn't connect. + if self.retry_interval: + self.retry_interval *= self.multiplier_retry_interval + + if self.retry_interval >= self.max_retry_interval: + self.retry_interval = self.max_retry_interval + else: + self.retry_interval = self.min_retry_interval + + retry_last_ts = int(self.clock.time_msec()) + + self.store.set_destination_retry_timings( + self.destination, retry_last_ts, self.retry_interval + ).addErrback(err) |