From 62c010283d543db0956066b42eb735b57c000a82 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 23 Jul 2015 16:03:38 +0100 Subject: Add federation support for end-to-end key requests --- synapse/federation/federation_client.py | 34 ++++++++++++++++ synapse/federation/federation_server.py | 37 +++++++++++++++++ synapse/federation/transport/client.py | 70 +++++++++++++++++++++++++++++++++ synapse/federation/transport/server.py | 20 ++++++++++ 4 files changed, 161 insertions(+) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7736d14fb5..21a86a4c6d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -134,6 +134,40 @@ class FederationClient(FederationBase): destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail ) + @log_function + def query_client_keys(self, destination, content, retry_on_dns_fail=True): + """Query device keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_device_keys") + return self.transport_layer.query_client_keys( + destination, content, retry_on_dns_fail=retry_on_dns_fail + ) + + @log_function + def claim_client_keys(self, destination, content, retry_on_dns_fail=True): + """Claims one-time keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_one_time_keys") + return self.transport_layer.claim_client_keys( + destination, content, retry_on_dns_fail=retry_on_dns_fail + ) + @defer.inlineCallbacks @log_function def backfill(self, dest, context, limit, extremities): diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index cd79e23f4b..c32908ac28 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError from synapse.crypto.event_signing import compute_event_signature +import simplejson as json import logging @@ -312,6 +313,42 @@ class FederationServer(FederationBase): (200, send_content) ) + @defer.inlineCallbacks + @log_function + def on_query_client_keys(self, origin, content): + query = [] + for user_id, device_ids in content.get("device_keys", {}).items(): + if not device_ids: + query.append((user_id, None)) + else: + for device_id in device_ids: + query.append((user_id, device_id)) + results = yield self.store.get_e2e_device_keys(query) + json_result = {} + for user_id, device_keys in results.items(): + for device_id, json_bytes in device_keys.items(): + json_result.setdefault(user_id, {})[device_id] = json.loads( + json_bytes + ) + defer.returnValue({"device_keys": json_result}) + + @defer.inlineCallbacks + @log_function + def on_claim_client_keys(self, origin, content): + query = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithm in device_keys.items(): + query.append((user_id, device_id, algorithm)) + results = yield self.store.claim_e2e_one_time_keys(query) + json_result = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + defer.returnValue({"one_time_keys": json_result}) + @defer.inlineCallbacks @log_function def on_get_missing_events(self, origin, room_id, earliest_events, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 610a4c3163..df5083dd22 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -222,6 +222,76 @@ class TransportLayerClient(object): defer.returnValue(content) + @defer.inlineCallbacks + @log_function + def query_client_keys(self, destination, query_content): + """Query the device keys for a list of user ids hosted on a remote + server. + + Request: + { + "device_keys": { + "": [""] + } } + + Response: + { + "device_keys": { + "": { + "": {...} + } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the device keys. + """ + path = PREFIX + "/client_keys/query" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + + @defer.inlineCallbacks + @log_function + def claim_client_keys(self, destination, query_content): + """Claim one-time keys for a list of devices hosted on a remote server. + + Request: + { + "one_time_keys": { + "": { + "": "" + } } } + + Response: + { + "device_keys": { + "": { + "": { + ":": "" + } } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the one-time keys. + """ + + path = PREFIX + "/client_keys/claim" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + @defer.inlineCallbacks @log_function def get_missing_events(self, destination, room_id, earliest_events, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bad93c6b2f..fb59383ecd 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet): defer.returnValue((200, content)) +class FederationClientKeysQueryServlet(BaseFederationServlet): + PATH = "/client_keys/query" + + @defer.inlineCallbacks + def on_POST(self, origin, content): + response = yield self.handler.on_client_key_query(origin, content) + defer.returnValue((200, response)) + + +class FederationClientKeysClaimServlet(BaseFederationServlet): + PATH = "/client_keys/claim" + + @defer.inlineCallbacks + def on_POST(self, origin, content): + response = yield self.handler.on_client_key_claim(origin, content) + defer.returnValue((200, response)) + + class FederationQueryAuthServlet(BaseFederationServlet): PATH = "/query_auth/([^/]*)/([^/]*)" @@ -373,4 +391,6 @@ SERVLET_CLASSES = ( FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, + FederationClientKeysQueryServlet, + FederationClientKeysClaimServlet, ) -- cgit 1.4.1 From 2da3b1e60bf7e9ae1d6714abcff0a0c224cadf28 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 24 Jul 2015 18:26:46 +0100 Subject: Get the end-to-end key federation working --- synapse/federation/federation_client.py | 12 ++++-------- synapse/federation/transport/client.py | 4 ++-- synapse/federation/transport/server.py | 12 ++++++------ synapse/rest/client/v2_alpha/keys.py | 10 +++++----- 4 files changed, 17 insertions(+), 21 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 21a86a4c6d..44e4d0755a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -135,7 +135,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content, retry_on_dns_fail=True): + def query_client_keys(self, destination, content): """Query device keys for a device hosted on a remote server. Args: @@ -147,12 +147,10 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys( - destination, content, retry_on_dns_fail=retry_on_dns_fail - ) + return self.transport_layer.query_client_keys(destination, content) @log_function - def claim_client_keys(self, destination, content, retry_on_dns_fail=True): + def claim_client_keys(self, destination, content): """Claims one-time keys for a device hosted on a remote server. Args: @@ -164,9 +162,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys( - destination, content, retry_on_dns_fail=retry_on_dns_fail - ) + return self.transport_layer.claim_client_keys(destination, content) @defer.inlineCallbacks @log_function diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index df5083dd22..ced703364b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -247,7 +247,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = PREFIX + "/client_keys/query" + path = PREFIX + "/user/keys/query" content = yield self.client.post_json( destination=destination, @@ -283,7 +283,7 @@ class TransportLayerClient(object): A dict containg the one-time keys. """ - path = PREFIX + "/client_keys/claim" + path = PREFIX + "/user/keys/claim" content = yield self.client.post_json( destination=destination, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fb59383ecd..36f250e1a3 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -326,20 +326,20 @@ class FederationInviteServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet): - PATH = "/client_keys/query" + PATH = "/user/keys/query" @defer.inlineCallbacks - def on_POST(self, origin, content): - response = yield self.handler.on_client_key_query(origin, content) + def on_POST(self, origin, content, query): + response = yield self.handler.on_query_client_keys(origin, content) defer.returnValue((200, response)) class FederationClientKeysClaimServlet(BaseFederationServlet): - PATH = "/client_keys/claim" + PATH = "/user/keys/claim" @defer.inlineCallbacks - def on_POST(self, origin, content): - response = yield self.handler.on_client_key_claim(origin, content) + def on_POST(self, origin, content, query): + response = yield self.handler.on_claim_client_keys(origin, content) defer.returnValue((200, response)) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 739a08ada8..718928eedd 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -202,7 +202,7 @@ class KeyQueryServlet(RestServlet): for device_id in device_ids: local_query.append((user_id, device_id)) else: - remote_queries.set_default(user.domain, {})[user_id] = list( + remote_queries.setdefault(user.domain, {})[user_id] = list( device_ids ) results = yield self.store.get_e2e_device_keys(local_query) @@ -218,7 +218,7 @@ class KeyQueryServlet(RestServlet): remote_result = yield self.federation.query_client_keys( destination, {"device_keys": device_keys} ) - for user_id, keys in remote_result.items(): + for user_id, keys in remote_result["device_keys"].items(): if user_id in device_keys: json_result[user_id] = keys defer.returnValue((200, {"device_keys": json_result})) @@ -286,7 +286,7 @@ class OneTimeKeyServlet(RestServlet): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: - remote_queries.set_default(user.domain, {})[user_id] = ( + remote_queries.setdefault(user.domain, {})[user_id] = ( device_keys ) results = yield self.store.claim_e2e_one_time_keys(local_query) @@ -300,10 +300,10 @@ class OneTimeKeyServlet(RestServlet): } for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( + remote_result = yield self.federation.claim_client_keys( destination, {"one_time_keys": device_keys} ) - for user_id, keys in remote_result.items(): + for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: json_result[user_id] = keys -- cgit 1.4.1 From 2df8dd9b37f26e3ad0d3647a1e78804a85d48c0c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2015 17:59:32 +0100 Subject: Move all the caches into their own package, synapse.util.caches --- synapse/federation/federation_client.py | 2 +- synapse/state.py | 2 +- synapse/storage/_base.py | 335 +---------------------------- synapse/storage/directory.py | 3 +- synapse/storage/event_federation.py | 3 +- synapse/storage/keys.py | 3 +- synapse/storage/presence.py | 3 +- synapse/storage/push_rule.py | 3 +- synapse/storage/receipts.py | 3 +- synapse/storage/registration.py | 3 +- synapse/storage/room.py | 3 +- synapse/storage/roommember.py | 3 +- synapse/storage/state.py | 5 +- synapse/storage/stream.py | 3 +- synapse/storage/transactions.py | 3 +- synapse/util/caches/__init__.py | 14 ++ synapse/util/caches/descriptors.py | 359 ++++++++++++++++++++++++++++++++ synapse/util/caches/dictionary_cache.py | 109 ++++++++++ synapse/util/caches/expiringcache.py | 115 ++++++++++ synapse/util/caches/lrucache.py | 149 +++++++++++++ synapse/util/dictionary_cache.py | 109 ---------- synapse/util/expiringcache.py | 115 ---------- synapse/util/lrucache.py | 149 ------------- tests/storage/test__base.py | 2 +- tests/util/test_dict_cache.py | 2 +- tests/util/test_lrucache.py | 4 +- 26 files changed, 780 insertions(+), 724 deletions(-) create mode 100644 synapse/util/caches/__init__.py create mode 100644 synapse/util/caches/descriptors.py create mode 100644 synapse/util/caches/dictionary_cache.py create mode 100644 synapse/util/caches/expiringcache.py create mode 100644 synapse/util/caches/lrucache.py delete mode 100644 synapse/util/dictionary_cache.py delete mode 100644 synapse/util/expiringcache.py delete mode 100644 synapse/util/lrucache.py (limited to 'synapse/federation') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7736d14fb5..58a6d6a0ed 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -23,7 +23,7 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent import synapse.metrics diff --git a/synapse/state.py b/synapse/state.py index b5e5d7bbda..1fe4d066bd 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor -from synapse.util.expiringcache import ExpiringCache +from synapse.util.caches.expiringcache import ExpiringCache from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.auth import AuthEventTypes diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e5441aafb2..1444767a52 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,27 +15,22 @@ import logging from synapse.api.errors import StoreError -from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext -from synapse.util.lrucache import LruCache -from synapse.util.dictionary_cache import DictionaryCache +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.caches.descriptors import Cache import synapse.metrics from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer -from collections import namedtuple, OrderedDict +from collections import namedtuple -import functools -import inspect import sys import time import threading -DEBUG_CACHES = False logger = logging.getLogger(__name__) @@ -51,330 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) -caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) - - -_CacheSentinel = object() - - -class Cache(object): - - def __init__(self, name, max_entries=1000, keylen=1, lru=True): - if lru: - self.cache = LruCache(max_size=max_entries) - self.max_entries = None - else: - self.cache = OrderedDict() - self.max_entries = max_entries - - self.name = name - self.keylen = keylen - self.sequence = 0 - self.thread = None - caches_by_name[name] = self.cache - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, key, default=_CacheSentinel): - val = self.cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) - return val - - cache_counter.inc_misses(self.name) - - if default is _CacheSentinel: - raise KeyError() - else: - return default - - def update(self, sequence, key, value): - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - self.prefill(key, value) - - def prefill(self, key, value): - if self.max_entries is not None: - while len(self.cache) >= self.max_entries: - self.cache.popitem(last=False) - - self.cache[key] = value - - def invalidate(self, key): - self.check_thread() - if not isinstance(key, tuple): - raise ValueError("keyargs must be a tuple.") - - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 - self.cache.pop(key, None) - - def invalidate_all(self): - self.check_thread() - self.sequence += 1 - self.cache.clear() - - -class CacheDescriptor(object): - """ A method decorator that applies a memoizing cache around the function. - - This caches deferreds, rather than the results themselves. Deferreds that - fail are removed from the cache. - - The function is presumed to take zero or more arguments, which are used in - a tuple as the key for the cache. Hits are served directly from the cache; - misses use the function body to generate the value. - - The wrapped function has an additional member, a callable called - "invalidate". This can be used to remove individual entries from the cache. - - The wrapped function has another additional callable, called "prefill", - which can be used to insert values into the cache specifically, without - calling the calculation function. - """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, - inlineCallbacks=False): - self.orig = orig - - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.max_entries = max_entries - self.num_args = num_args - self.lru = lru - - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - - self.cache = Cache( - name=self.orig.__name__, - max_entries=self.max_entries, - keylen=self.num_args, - lru=self.lru, - ) - - def __get__(self, obj, objtype=None): - - @functools.wraps(self.orig) - def wrapped(*args, **kwargs): - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) - try: - cached_result_d = self.cache.get(cache_key) - - observer = cached_result_d.observe() - if DEBUG_CACHES: - @defer.inlineCallbacks - def check_result(cached_result): - actual_result = yield self.function_to_call(obj, *args, **kwargs) - if actual_result != cached_result: - logger.error( - "Stale cache entry %s%r: cached: %r, actual %r", - self.orig.__name__, cache_key, - cached_result, actual_result, - ) - raise ValueError("Stale cache entry") - defer.returnValue(cached_result) - observer.addCallback(check_result) - - return observer - except KeyError: - # Get the sequence number of the cache before reading from the - # database so that we can tell if the cache is invalidated - # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence - - ret = defer.maybeDeferred( - self.function_to_call, - obj, *args, **kwargs - ) - - def onErr(f): - self.cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) - - return ret.observe() - - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.prefill = self.cache.prefill - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -class CacheListDescriptor(object): - """Wraps an existing cache to support bulk fetching of keys. - - Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped fucntion. - """ - - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): - """ - Args: - orig (function) - cache (Cache) - list_name (str): Name of the argument which is the bulk lookup list - num_args (int) - inlineCallbacks (bool): Whether orig is a generator that should - be wrapped by defer.inlineCallbacks - """ - self.orig = orig - - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.num_args = num_args - self.list_name = list_name - - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] - self.list_pos = self.arg_names.index(self.list_name) - - self.cache = cache - - self.sentinel = object() - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - - if self.list_name not in self.arg_names: - raise Exception( - "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) - ) - - def __get__(self, obj, objtype=None): - - @functools.wraps(self.orig) - def wrapped(*args, **kwargs): - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] - list_args = arg_dict[self.list_name] - - # cached is a dict arg -> deferred, where deferred results in a - # 2-tuple (`arg`, `result`) - cached = {} - missing = [] - for arg in list_args: - key = list(keyargs) - key[self.list_pos] = arg - - try: - res = self.cache.get(tuple(key)).observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res - except KeyError: - missing.append(arg) - - if missing: - sequence = self.cache.sequence - args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing - - ret_d = defer.maybeDeferred( - self.function_to_call, - **args_to_call - ) - - ret_d = ObservableDeferred(ret_d) - - # We need to create deferreds for each arg in the list so that - # we can insert the new deferred into the cache. - for arg in missing: - observer = ret_d.observe() - observer.addCallback(lambda r, arg: r[arg], arg) - - observer = ObservableDeferred(observer) - - key = list(keyargs) - key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) - - def invalidate(f, key): - self.cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) - - res = observer.observe() - res.addCallback(lambda r, arg: (arg, r), arg) - - cached[arg] = res - - return defer.gatherResults( - cached.values(), - consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) - - obj.__dict__[self.orig.__name__] = wrapped - - return wrapped - - -def cached(max_entries=1000, num_args=1, lru=True): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru - ) - - -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): - return lambda orig: CacheDescriptor( - orig, - max_entries=max_entries, - num_args=num_args, - lru=lru, - inlineCallbacks=True, - ) - - -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): - return lambda orig: CacheListDescriptor( - orig, - cache=cache, - list_name=list_name, - num_args=num_args, - inlineCallbacks=inlineCallbacks, - ) - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index f3947bbe89..d92028ea43 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.errors import SynapseError diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 910b6598a7..25cc84eb95 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -15,7 +15,8 @@ from twisted.internet import defer -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from syutil.base64util import encode_base64 import logging diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 49b8e37cfd..ffd6daa880 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, cachedInlineCallbacks +from _base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 576cf670cc..4f91a2b87c 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from twisted.internet import defer diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9b88ca7b39..5305b7e122 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer import logging diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index b79d6683ca..cac1a5657e 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4eaa088b36..aa446f94c6 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached class RegistrationStore(SQLBaseStore): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index dd5bc2c8fb..5e07b7e0e5 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks import collections import logging diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9f14f38f24..8eee2dfbcc 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -17,7 +17,8 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from synapse.api.constants import Membership from synapse.types import UserID diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ea5fa9de7b..79c3b82d9f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached, cachedInlineCallbacks, cachedList +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import ( + cached, cachedInlineCallbacks, cachedList +) from twisted.internet import defer diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b59fe81004..d7fe423f5a 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,7 +35,8 @@ what sort order was used: from twisted.internet import defer -from ._base import SQLBaseStore, cachedInlineCallbacks +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 624da4a9dc..c8c7e6591a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, cached +from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached from collections import namedtuple diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py new file mode 100644 index 0000000000..1a84d94cd9 --- /dev/null +++ b/synapse/util/caches/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py new file mode 100644 index 0000000000..82dd09cf5e --- /dev/null +++ b/synapse/util/caches/descriptors.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError +from synapse.util.caches.lrucache import LruCache +import synapse.metrics + +from twisted.internet import defer + +from collections import OrderedDict + +import functools +import inspect +import threading + +logger = logging.getLogger(__name__) + + +DEBUG_CACHES = False + +metrics = synapse.metrics.get_metrics_for("synapse.util.caches") + +caches_by_name = {} +cache_counter = metrics.register_cache( + "cache", + lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, + labels=["name"], +) + + +_CacheSentinel = object() + + +class Cache(object): + + def __init__(self, name, max_entries=1000, keylen=1, lru=True): + if lru: + self.cache = LruCache(max_size=max_entries) + self.max_entries = None + else: + self.cache = OrderedDict() + self.max_entries = max_entries + + self.name = name + self.keylen = keylen + self.sequence = 0 + self.thread = None + caches_by_name[name] = self.cache + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, default=_CacheSentinel): + val = self.cache.get(key, _CacheSentinel) + if val is not _CacheSentinel: + cache_counter.inc_hits(self.name) + return val + + cache_counter.inc_misses(self.name) + + if default is _CacheSentinel: + raise KeyError() + else: + return default + + def update(self, sequence, key, value): + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + self.prefill(key, value) + + def prefill(self, key, value): + if self.max_entries is not None: + while len(self.cache) >= self.max_entries: + self.cache.popitem(last=False) + + self.cache[key] = value + + def invalidate(self, key): + self.check_thread() + if not isinstance(key, tuple): + raise ValueError("keyargs must be a tuple.") + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + +class CacheDescriptor(object): + """ A method decorator that applies a memoizing cache around the function. + + This caches deferreds, rather than the results themselves. Deferreds that + fail are removed from the cache. + + The function is presumed to take zero or more arguments, which are used in + a tuple as the key for the cache. Hits are served directly from the cache; + misses use the function body to generate the value. + + The wrapped function has an additional member, a callable called + "invalidate". This can be used to remove individual entries from the cache. + + The wrapped function has another additional callable, called "prefill", + which can be used to insert values into the cache specifically, without + calling the calculation function. + """ + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + inlineCallbacks=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.max_entries = max_entries + self.num_args = num_args + self.lru = lru + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + self.cache = Cache( + name=self.orig.__name__, + max_entries=self.max_entries, + keylen=self.num_args, + lru=self.lru, + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + try: + cached_result_d = self.cache.get(cache_key) + + observer = cached_result_d.observe() + if DEBUG_CACHES: + @defer.inlineCallbacks + def check_result(cached_result): + actual_result = yield self.function_to_call(obj, *args, **kwargs) + if actual_result != cached_result: + logger.error( + "Stale cache entry %s%r: cached: %r, actual %r", + self.orig.__name__, cache_key, + cached_result, actual_result, + ) + raise ValueError("Stale cache entry") + defer.returnValue(cached_result) + observer.addCallback(check_result) + + return observer + except KeyError: + # Get the sequence number of the cache before reading from the + # database so that we can tell if the cache is invalidated + # while the SELECT is executing (SYN-369) + sequence = self.cache.sequence + + ret = defer.maybeDeferred( + self.function_to_call, + obj, *args, **kwargs + ) + + def onErr(f): + self.cache.invalidate(cache_key) + return f + + ret.addErrback(onErr) + + ret = ObservableDeferred(ret, consumeErrors=True) + self.cache.update(sequence, cache_key, ret) + + return ret.observe() + + wrapped.invalidate = self.cache.invalidate + wrapped.invalidate_all = self.cache.invalidate_all + wrapped.prefill = self.cache.prefill + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class CacheListDescriptor(object): + """Wraps an existing cache to support bulk fetching of keys. + + Given a list of keys it looks in the cache to find any hits, then passes + the list of missing keys to the wrapped fucntion. + """ + + def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + """ + Args: + orig (function) + cache (Cache) + list_name (str): Name of the argument which is the bulk lookup list + num_args (int) + inlineCallbacks (bool): Whether orig is a generator that should + be wrapped by defer.inlineCallbacks + """ + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + self.num_args = num_args + self.list_name = list_name + + self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.list_pos = self.arg_names.index(self.list_name) + + self.cache = cache + + self.sentinel = object() + + if len(self.arg_names) < self.num_args: + raise Exception( + "Not enough explicit positional arguments to key off of for %r." + " (@cached cannot key off of *args or **kwars)" + % (orig.__name__,) + ) + + if self.list_name not in self.arg_names: + raise Exception( + "Couldn't see arguments %r for %r." + % (self.list_name, cache.name,) + ) + + def __get__(self, obj, objtype=None): + + @functools.wraps(self.orig) + def wrapped(*args, **kwargs): + arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) + keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] + list_args = arg_dict[self.list_name] + + # cached is a dict arg -> deferred, where deferred results in a + # 2-tuple (`arg`, `result`) + cached = {} + missing = [] + for arg in list_args: + key = list(keyargs) + key[self.list_pos] = arg + + try: + res = self.cache.get(tuple(key)).observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached[arg] = res + except KeyError: + missing.append(arg) + + if missing: + sequence = self.cache.sequence + args_to_call = dict(arg_dict) + args_to_call[self.list_name] = missing + + ret_d = defer.maybeDeferred( + self.function_to_call, + **args_to_call + ) + + ret_d = ObservableDeferred(ret_d) + + # We need to create deferreds for each arg in the list so that + # we can insert the new deferred into the cache. + for arg in missing: + observer = ret_d.observe() + observer.addCallback(lambda r, arg: r[arg], arg) + + observer = ObservableDeferred(observer) + + key = list(keyargs) + key[self.list_pos] = arg + self.cache.update(sequence, tuple(key), observer) + + def invalidate(f, key): + self.cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) + + res = observer.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + + cached[arg] = res + + return defer.gatherResults( + cached.values(), + consumeErrors=True, + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +def cached(max_entries=1000, num_args=1, lru=True): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru + ) + + +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): + return lambda orig: CacheDescriptor( + orig, + max_entries=max_entries, + num_args=num_args, + lru=lru, + inlineCallbacks=True, + ) + + +def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): + return lambda orig: CacheListDescriptor( + orig, + cache=cache, + list_name=list_name, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + ) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py new file mode 100644 index 0000000000..26d464f4f7 --- /dev/null +++ b/synapse/util/caches/dictionary_cache.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.caches.lrucache import LruCache +from collections import namedtuple +import threading +import logging + + +logger = logging.getLogger(__name__) + + +DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) + + +class DictionaryCache(object): + """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. + fetching a subset of dictionary keys for a particular key. + """ + + def __init__(self, name, max_entries=1000): + self.cache = LruCache(max_size=max_entries) + + self.name = name + self.sequence = 0 + self.thread = None + # caches_by_name[name] = self.cache + + class Sentinel(object): + __slots__ = [] + + self.sentinel = Sentinel() + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get(self, key, dict_keys=None): + try: + entry = self.cache.get(key, self.sentinel) + if entry is not self.sentinel: + # cache_counter.inc_hits(self.name) + + if dict_keys is None: + return DictionaryEntry(entry.full, dict(entry.value)) + else: + return DictionaryEntry(entry.full, { + k: entry.value[k] + for k in dict_keys + if k in entry.value + }) + + # cache_counter.inc_misses(self.name) + return DictionaryEntry(False, {}) + except: + logger.exception("get failed") + raise + + def invalidate(self, key): + self.check_thread() + + # Increment the sequence number so that any SELECT statements that + # raced with the INSERT don't update the cache (SYN-369) + self.sequence += 1 + self.cache.pop(key, None) + + def invalidate_all(self): + self.check_thread() + self.sequence += 1 + self.cache.clear() + + def update(self, sequence, key, value, full=False): + try: + self.check_thread() + if self.sequence == sequence: + # Only update the cache if the caches sequence number matches the + # number that the cache had before the SELECT was started (SYN-369) + if full: + self._insert(key, value) + else: + self._update_or_insert(key, value) + except: + logger.exception("update failed") + raise + + def _update_or_insert(self, key, value): + entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + entry.value.update(value) + + def _insert(self, key, value): + self.cache[key] = DictionaryEntry(True, value) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py new file mode 100644 index 0000000000..06d1eea01b --- /dev/null +++ b/synapse/util/caches/expiringcache.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +logger = logging.getLogger(__name__) + + +class ExpiringCache(object): + def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, + reset_expiry_on_get=False): + """ + Args: + cache_name (str): Name of this cache, used for logging. + clock (Clock) + max_len (int): Max size of dict. If the dict grows larger than this + then the oldest items get automatically evicted. Default is 0, + which indicates there is no max limit. + expiry_ms (int): How long before an item is evicted from the cache + in milliseconds. Default is 0, indicating items never get + evicted based on time. + reset_expiry_on_get (bool): If true, will reset the expiry time for + an item on access. Defaults to False. + + """ + self._cache_name = cache_name + + self._clock = clock + + self._max_len = max_len + self._expiry_ms = expiry_ms + + self._reset_expiry_on_get = reset_expiry_on_get + + self._cache = {} + + def start(self): + if not self._expiry_ms: + # Don't bother starting the loop if things never expire + return + + def f(): + self._prune_cache() + + self._clock.looping_call(f, self._expiry_ms/2) + + def __setitem__(self, key, value): + now = self._clock.time_msec() + self._cache[key] = _CacheEntry(now, value) + + # Evict if there are now too many items + if self._max_len and len(self._cache.keys()) > self._max_len: + sorted_entries = sorted( + self._cache.items(), + key=lambda (k, v): v.time, + ) + + for k, _ in sorted_entries[self._max_len:]: + self._cache.pop(k) + + def __getitem__(self, key): + entry = self._cache[key] + + if self._reset_expiry_on_get: + entry.time = self._clock.time_msec() + + return entry.value + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def _prune_cache(self): + if not self._expiry_ms: + # zero expiry time means don't expire. This should never get called + # since we have this check in start too. + return + begin_length = len(self._cache) + + now = self._clock.time_msec() + + keys_to_delete = set() + + for key, cache_entry in self._cache.items(): + if now - cache_entry.time > self._expiry_ms: + keys_to_delete.add(key) + + for k in keys_to_delete: + self._cache.pop(k) + + logger.debug( + "[%s] _prune_cache before: %d, after len: %d", + self._cache_name, begin_length, len(self._cache.keys()) + ) + + +class _CacheEntry(object): + def __init__(self, time, value): + self.time = time + self.value = value diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py new file mode 100644 index 0000000000..cacd7e45fa --- /dev/null +++ b/synapse/util/caches/lrucache.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import wraps +import threading + + +class LruCache(object): + """Least-recently-used cache.""" + def __init__(self, max_size): + cache = {} + list_root = [] + list_root[:] = [list_root, list_root, None, None] + + PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 + + lock = threading.Lock() + + def synchronized(f): + @wraps(f) + def inner(*args, **kwargs): + with lock: + return f(*args, **kwargs) + + return inner + + def add_node(key, value): + prev_node = list_root + next_node = prev_node[NEXT] + node = [prev_node, next_node, key, value] + prev_node[NEXT] = node + next_node[PREV] = node + cache[key] = node + + def move_node_to_front(node): + prev_node = node[PREV] + next_node = node[NEXT] + prev_node[NEXT] = next_node + next_node[PREV] = prev_node + prev_node = list_root + next_node = prev_node[NEXT] + node[PREV] = prev_node + node[NEXT] = next_node + prev_node[NEXT] = node + next_node[PREV] = node + + def delete_node(node): + prev_node = node[PREV] + next_node = node[NEXT] + prev_node[NEXT] = next_node + next_node[PREV] = prev_node + cache.pop(node[KEY], None) + + @synchronized + def cache_get(key, default=None): + node = cache.get(key, None) + if node is not None: + move_node_to_front(node) + return node[VALUE] + else: + return default + + @synchronized + def cache_set(key, value): + node = cache.get(key, None) + if node is not None: + move_node_to_front(node) + node[VALUE] = value + else: + add_node(key, value) + if len(cache) > max_size: + delete_node(list_root[PREV]) + + @synchronized + def cache_set_default(key, value): + node = cache.get(key, None) + if node is not None: + return node[VALUE] + else: + add_node(key, value) + if len(cache) > max_size: + delete_node(list_root[PREV]) + return value + + @synchronized + def cache_pop(key, default=None): + node = cache.get(key, None) + if node: + delete_node(node) + return node[VALUE] + else: + return default + + @synchronized + def cache_clear(): + list_root[NEXT] = list_root + list_root[PREV] = list_root + cache.clear() + + @synchronized + def cache_len(): + return len(cache) + + @synchronized + def cache_contains(key): + return key in cache + + self.sentinel = object() + self.get = cache_get + self.set = cache_set + self.setdefault = cache_set_default + self.pop = cache_pop + self.len = cache_len + self.contains = cache_contains + self.clear = cache_clear + + def __getitem__(self, key): + result = self.get(key, self.sentinel) + if result is self.sentinel: + raise KeyError() + else: + return result + + def __setitem__(self, key, value): + self.set(key, value) + + def __delitem__(self, key, value): + result = self.pop(key, self.sentinel) + if result is self.sentinel: + raise KeyError() + + def __len__(self): + return self.len() + + def __contains__(self, key): + return self.contains(key) diff --git a/synapse/util/dictionary_cache.py b/synapse/util/dictionary_cache.py deleted file mode 100644 index c7564cdf0d..0000000000 --- a/synapse/util/dictionary_cache.py +++ /dev/null @@ -1,109 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.lrucache import LruCache -from collections import namedtuple -import threading -import logging - - -logger = logging.getLogger(__name__) - - -DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) - - -class DictionaryCache(object): - """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. - fetching a subset of dictionary keys for a particular key. - """ - - def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) - - self.name = name - self.sequence = 0 - self.thread = None - # caches_by_name[name] = self.cache - - class Sentinel(object): - __slots__ = [] - - self.sentinel = Sentinel() - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, key, dict_keys=None): - try: - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: - # cache_counter.inc_hits(self.name) - - if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) - else: - return DictionaryEntry(entry.full, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) - - # cache_counter.inc_misses(self.name) - return DictionaryEntry(False, {}) - except: - logger.exception("get failed") - raise - - def invalidate(self, key): - self.check_thread() - - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 - self.cache.pop(key, None) - - def invalidate_all(self): - self.check_thread() - self.sequence += 1 - self.cache.clear() - - def update(self, sequence, key, value, full=False): - try: - self.check_thread() - if self.sequence == sequence: - # Only update the cache if the caches sequence number matches the - # number that the cache had before the SELECT was started (SYN-369) - if full: - self._insert(key, value) - else: - self._update_or_insert(key, value) - except: - logger.exception("update failed") - raise - - def _update_or_insert(self, key, value): - entry = self.cache.setdefault(key, DictionaryEntry(False, {})) - entry.value.update(value) - - def _insert(self, key, value): - self.cache[key] = DictionaryEntry(True, value) diff --git a/synapse/util/expiringcache.py b/synapse/util/expiringcache.py deleted file mode 100644 index 06d1eea01b..0000000000 --- a/synapse/util/expiringcache.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - - -logger = logging.getLogger(__name__) - - -class ExpiringCache(object): - def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False): - """ - Args: - cache_name (str): Name of this cache, used for logging. - clock (Clock) - max_len (int): Max size of dict. If the dict grows larger than this - then the oldest items get automatically evicted. Default is 0, - which indicates there is no max limit. - expiry_ms (int): How long before an item is evicted from the cache - in milliseconds. Default is 0, indicating items never get - evicted based on time. - reset_expiry_on_get (bool): If true, will reset the expiry time for - an item on access. Defaults to False. - - """ - self._cache_name = cache_name - - self._clock = clock - - self._max_len = max_len - self._expiry_ms = expiry_ms - - self._reset_expiry_on_get = reset_expiry_on_get - - self._cache = {} - - def start(self): - if not self._expiry_ms: - # Don't bother starting the loop if things never expire - return - - def f(): - self._prune_cache() - - self._clock.looping_call(f, self._expiry_ms/2) - - def __setitem__(self, key, value): - now = self._clock.time_msec() - self._cache[key] = _CacheEntry(now, value) - - # Evict if there are now too many items - if self._max_len and len(self._cache.keys()) > self._max_len: - sorted_entries = sorted( - self._cache.items(), - key=lambda (k, v): v.time, - ) - - for k, _ in sorted_entries[self._max_len:]: - self._cache.pop(k) - - def __getitem__(self, key): - entry = self._cache[key] - - if self._reset_expiry_on_get: - entry.time = self._clock.time_msec() - - return entry.value - - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def _prune_cache(self): - if not self._expiry_ms: - # zero expiry time means don't expire. This should never get called - # since we have this check in start too. - return - begin_length = len(self._cache) - - now = self._clock.time_msec() - - keys_to_delete = set() - - for key, cache_entry in self._cache.items(): - if now - cache_entry.time > self._expiry_ms: - keys_to_delete.add(key) - - for k in keys_to_delete: - self._cache.pop(k) - - logger.debug( - "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self._cache.keys()) - ) - - -class _CacheEntry(object): - def __init__(self, time, value): - self.time = time - self.value = value diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py deleted file mode 100644 index cacd7e45fa..0000000000 --- a/synapse/util/lrucache.py +++ /dev/null @@ -1,149 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from functools import wraps -import threading - - -class LruCache(object): - """Least-recently-used cache.""" - def __init__(self, max_size): - cache = {} - list_root = [] - list_root[:] = [list_root, list_root, None, None] - - PREV, NEXT, KEY, VALUE = 0, 1, 2, 3 - - lock = threading.Lock() - - def synchronized(f): - @wraps(f) - def inner(*args, **kwargs): - with lock: - return f(*args, **kwargs) - - return inner - - def add_node(key, value): - prev_node = list_root - next_node = prev_node[NEXT] - node = [prev_node, next_node, key, value] - prev_node[NEXT] = node - next_node[PREV] = node - cache[key] = node - - def move_node_to_front(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node - prev_node = list_root - next_node = prev_node[NEXT] - node[PREV] = prev_node - node[NEXT] = next_node - prev_node[NEXT] = node - next_node[PREV] = node - - def delete_node(node): - prev_node = node[PREV] - next_node = node[NEXT] - prev_node[NEXT] = next_node - next_node[PREV] = prev_node - cache.pop(node[KEY], None) - - @synchronized - def cache_get(key, default=None): - node = cache.get(key, None) - if node is not None: - move_node_to_front(node) - return node[VALUE] - else: - return default - - @synchronized - def cache_set(key, value): - node = cache.get(key, None) - if node is not None: - move_node_to_front(node) - node[VALUE] = value - else: - add_node(key, value) - if len(cache) > max_size: - delete_node(list_root[PREV]) - - @synchronized - def cache_set_default(key, value): - node = cache.get(key, None) - if node is not None: - return node[VALUE] - else: - add_node(key, value) - if len(cache) > max_size: - delete_node(list_root[PREV]) - return value - - @synchronized - def cache_pop(key, default=None): - node = cache.get(key, None) - if node: - delete_node(node) - return node[VALUE] - else: - return default - - @synchronized - def cache_clear(): - list_root[NEXT] = list_root - list_root[PREV] = list_root - cache.clear() - - @synchronized - def cache_len(): - return len(cache) - - @synchronized - def cache_contains(key): - return key in cache - - self.sentinel = object() - self.get = cache_get - self.set = cache_set - self.setdefault = cache_set_default - self.pop = cache_pop - self.len = cache_len - self.contains = cache_contains - self.clear = cache_clear - - def __getitem__(self, key): - result = self.get(key, self.sentinel) - if result is self.sentinel: - raise KeyError() - else: - return result - - def __setitem__(self, key, value): - self.set(key, value) - - def __delitem__(self, key, value): - result = self.pop(key, self.sentinel) - if result is self.sentinel: - raise KeyError() - - def __len__(self): - return self.len() - - def __contains__(self, key): - return self.contains(key) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index abee2f631d..e72cace8ff 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.util.async import ObservableDeferred -from synapse.storage._base import Cache, cached +from synapse.util.caches.descriptors import Cache, cached class CacheTestCase(unittest.TestCase): diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 79bc1225d6..54ff26cd97 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -17,7 +17,7 @@ from twisted.internet import defer from tests import unittest -from synapse.util.dictionary_cache import DictionaryCache +from synapse.util.caches.dictionary_cache import DictionaryCache class DictCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index ab934bf928..fc5a904323 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -16,7 +16,7 @@ from .. import unittest -from synapse.util.lrucache import LruCache +from synapse.util.caches.lrucache import LruCache class LruCacheTestCase(unittest.TestCase): @@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase): cache["key"] = 1 self.assertEquals(cache.pop("key"), 1) self.assertEquals(cache.pop("key"), None) - - -- cgit 1.4.1 From 0cceb2ac92cde0a4289adfc6e9000c7b1c54bdae Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 13 Aug 2015 17:27:46 +0100 Subject: Add a few strategic new lines to break up the on_query_client_keys and on_claim_client_keys methods in federation_server.py --- synapse/federation/federation_server.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c32908ac28..725c6f3fa5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -323,13 +323,16 @@ class FederationServer(FederationBase): else: for device_id in device_ids: query.append((user_id, device_id)) + results = yield self.store.get_e2e_device_keys(query) + json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) + defer.returnValue({"device_keys": json_result}) @defer.inlineCallbacks @@ -339,7 +342,9 @@ class FederationServer(FederationBase): for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) + results = yield self.store.claim_e2e_one_time_keys(query) + json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): @@ -347,6 +352,7 @@ class FederationServer(FederationBase): json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } + defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks -- cgit 1.4.1