diff options
-rw-r--r-- | README.rst | 2 | ||||
-rw-r--r-- | synapse/api/auth.py | 5 | ||||
-rw-r--r-- | synapse/api/filtering.py | 261 | ||||
-rw-r--r-- | synapse/crypto/keyring.py | 23 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 2 | ||||
-rw-r--r-- | synapse/python_dependencies.py | 1 | ||||
-rw-r--r-- | synapse/storage/keys.py | 5 | ||||
-rw-r--r-- | synapse/util/caches/descriptors.py | 135 | ||||
-rw-r--r-- | synapse/util/retryutils.py | 5 | ||||
-rw-r--r-- | tests/api/test_filtering.py | 67 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_filter.py | 4 | ||||
-rw-r--r-- | tests/storage/test_keys.py | 53 | ||||
-rw-r--r-- | tests/util/caches/__init__.py | 14 | ||||
-rw-r--r-- | tests/util/caches/test_descriptors.py | 86 |
14 files changed, 494 insertions, 169 deletions
diff --git a/README.rst b/README.rst index 56f92a680d..e975da8162 100644 --- a/README.rst +++ b/README.rst @@ -808,7 +808,7 @@ directory of your choice:: Synapse has a number of external dependencies, that are easiest to install using pip and a virtualenv:: - virtualenv env + virtualenv -p python2.7 env source env/bin/activate python synapse/python_dependencies.py | xargs pip install pip install lxml mock diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 03a215ab1b..9dbc7993df 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -23,7 +23,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes from synapse.types import UserID -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util import logcontext from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -209,8 +209,7 @@ class Auth(object): default=[""] )[0] if user and access_token and ip_addr: - preserve_context_over_fn( - self.store.insert_client_ip, + logcontext.preserve_fn(self.store.insert_client_ip)( user=user, access_token=access_token, ip=ip_addr, diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 47f0cf0fa9..83206348e5 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -15,10 +15,172 @@ from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState from synapse.types import UserID, RoomID - from twisted.internet import defer import ujson as json +import jsonschema +from jsonschema import FormatChecker + +FILTER_SCHEMA = { + "additionalProperties": False, + "type": "object", + "properties": { + "limit": { + "type": "number" + }, + "senders": { + "$ref": "#/definitions/user_id_array" + }, + "not_senders": { + "$ref": "#/definitions/user_id_array" + }, + # TODO: We don't limit event type values but we probably should... + # check types are valid event types + "types": { + "type": "array", + "items": { + "type": "string" + } + }, + "not_types": { + "type": "array", + "items": { + "type": "string" + } + } + } +} + +ROOM_FILTER_SCHEMA = { + "additionalProperties": False, + "type": "object", + "properties": { + "not_rooms": { + "$ref": "#/definitions/room_id_array" + }, + "rooms": { + "$ref": "#/definitions/room_id_array" + }, + "ephemeral": { + "$ref": "#/definitions/room_event_filter" + }, + "include_leave": { + "type": "boolean" + }, + "state": { + "$ref": "#/definitions/room_event_filter" + }, + "timeline": { + "$ref": "#/definitions/room_event_filter" + }, + "account_data": { + "$ref": "#/definitions/room_event_filter" + }, + } +} + +ROOM_EVENT_FILTER_SCHEMA = { + "additionalProperties": False, + "type": "object", + "properties": { + "limit": { + "type": "number" + }, + "senders": { + "$ref": "#/definitions/user_id_array" + }, + "not_senders": { + "$ref": "#/definitions/user_id_array" + }, + "types": { + "type": "array", + "items": { + "type": "string" + } + }, + "not_types": { + "type": "array", + "items": { + "type": "string" + } + }, + "rooms": { + "$ref": "#/definitions/room_id_array" + }, + "not_rooms": { + "$ref": "#/definitions/room_id_array" + }, + "contains_url": { + "type": "boolean" + } + } +} + +USER_ID_ARRAY_SCHEMA = { + "type": "array", + "items": { + "type": "string", + "format": "matrix_user_id" + } +} + +ROOM_ID_ARRAY_SCHEMA = { + "type": "array", + "items": { + "type": "string", + "format": "matrix_room_id" + } +} + +USER_FILTER_SCHEMA = { + "$schema": "http://json-schema.org/draft-04/schema#", + "description": "schema for a Sync filter", + "type": "object", + "definitions": { + "room_id_array": ROOM_ID_ARRAY_SCHEMA, + "user_id_array": USER_ID_ARRAY_SCHEMA, + "filter": FILTER_SCHEMA, + "room_filter": ROOM_FILTER_SCHEMA, + "room_event_filter": ROOM_EVENT_FILTER_SCHEMA + }, + "properties": { + "presence": { + "$ref": "#/definitions/filter" + }, + "account_data": { + "$ref": "#/definitions/filter" + }, + "room": { + "$ref": "#/definitions/room_filter" + }, + "event_format": { + "type": "string", + "enum": ["client", "federation"] + }, + "event_fields": { + "type": "array", + "items": { + "type": "string", + # Don't allow '\\' in event field filters. This makes matching + # events a lot easier as we can then use a negative lookbehind + # assertion to split '\.' If we allowed \\ then it would + # incorrectly split '\\.' See synapse.events.utils.serialize_event + "pattern": "^((?!\\\).)*$" + } + } + }, + "additionalProperties": False +} + + +@FormatChecker.cls_checks('matrix_room_id') +def matrix_room_id_validator(room_id_str): + return RoomID.from_string(room_id_str) + + +@FormatChecker.cls_checks('matrix_user_id') +def matrix_user_id_validator(user_id_str): + return UserID.from_string(user_id_str) class Filtering(object): @@ -53,98 +215,11 @@ class Filtering(object): # NB: Filters are the complete json blobs. "Definitions" are an # individual top-level key e.g. public_user_data. Filters are made of # many definitions. - - top_level_definitions = [ - "presence", "account_data" - ] - - room_level_definitions = [ - "state", "timeline", "ephemeral", "account_data" - ] - - for key in top_level_definitions: - if key in user_filter_json: - self._check_definition(user_filter_json[key]) - - if "room" in user_filter_json: - self._check_definition_room_lists(user_filter_json["room"]) - for key in room_level_definitions: - if key in user_filter_json["room"]: - self._check_definition(user_filter_json["room"][key]) - - if "event_fields" in user_filter_json: - if type(user_filter_json["event_fields"]) != list: - raise SynapseError(400, "event_fields must be a list of strings") - for field in user_filter_json["event_fields"]: - if not isinstance(field, basestring): - raise SynapseError(400, "Event field must be a string") - # Don't allow '\\' in event field filters. This makes matching - # events a lot easier as we can then use a negative lookbehind - # assertion to split '\.' If we allowed \\ then it would - # incorrectly split '\\.' See synapse.events.utils.serialize_event - if r'\\' in field: - raise SynapseError( - 400, r'The escape character \ cannot itself be escaped' - ) - - def _check_definition_room_lists(self, definition): - """Check that "rooms" and "not_rooms" are lists of room ids if they - are present - - Args: - definition(dict): The filter definition - Raises: - SynapseError: If there was a problem with this definition. - """ - # check rooms are valid room IDs - room_id_keys = ["rooms", "not_rooms"] - for key in room_id_keys: - if key in definition: - if type(definition[key]) != list: - raise SynapseError(400, "Expected %s to be a list." % key) - for room_id in definition[key]: - RoomID.from_string(room_id) - - def _check_definition(self, definition): - """Check if the provided definition is valid. - - This inspects not only the types but also the values to make sure they - make sense. - - Args: - definition(dict): The filter definition - Raises: - SynapseError: If there was a problem with this definition. - """ - # NB: Filters are the complete json blobs. "Definitions" are an - # individual top-level key e.g. public_user_data. Filters are made of - # many definitions. - if type(definition) != dict: - raise SynapseError( - 400, "Expected JSON object, not %s" % (definition,) - ) - - self._check_definition_room_lists(definition) - - # check senders are valid user IDs - user_id_keys = ["senders", "not_senders"] - for key in user_id_keys: - if key in definition: - if type(definition[key]) != list: - raise SynapseError(400, "Expected %s to be a list." % key) - for user_id in definition[key]: - UserID.from_string(user_id) - - # TODO: We don't limit event type values but we probably should... - # check types are valid event types - event_keys = ["types", "not_types"] - for key in event_keys: - if key in definition: - if type(definition[key]) != list: - raise SynapseError(400, "Expected %s to be a list." % key) - for event_type in definition[key]: - if not isinstance(event_type, basestring): - raise SynapseError(400, "Event type should be a string") + try: + jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, + format_checker=FormatChecker()) + except jsonschema.ValidationError as e: + raise SynapseError(400, e.message) class FilterCollection(object): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 80f27f8c53..0a21392a62 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -238,8 +238,14 @@ class Keyring(object): d.addBoth(rm, server_name) def get_server_verify_keys(self, verify_requests): - """Takes a dict of KeyGroups and tries to find at least one key for - each group. + """Tries to find at least one key for each verify request + + For each verify_request, verify_request.deferred is called back with + params (server_name, key_id, VerifyKey) if a key is found, or errbacked + with a SynapseError if none of the keys are found. + + Args: + verify_requests (list[VerifyKeyRequest]): list of verify requests """ # These are functions that produce keys given a list of key ids @@ -252,8 +258,11 @@ class Keyring(object): @defer.inlineCallbacks def do_iterations(): with Measure(self.clock, "get_server_verify_keys"): + # dict[str, dict[str, VerifyKey]]: results so far. + # map server_name -> key_id -> VerifyKey merged_results = {} + # dict[str, set(str)]: keys to fetch for each server missing_keys = {} for verify_request in verify_requests: missing_keys.setdefault(verify_request.server_name, set()).update( @@ -315,6 +324,16 @@ class Keyring(object): @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): + """ + + Args: + server_name_and_key_ids (list[(str, iterable[str])]): + list of (server_name, iterable[key_id]) tuples to fetch keys for + + Returns: + Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from + server_name -> key_id -> VerifyKey + """ res = yield preserve_context_over_deferred(defer.gatherResults( [ preserve_fn(self.store.get_server_verify_keys)( diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index e40495d1ab..c02d41a74c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -316,7 +316,7 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - self.device_handler.check_device_registered(user_id, device_id) + yield self.device_handler.check_device_registered(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index c4777b2a2b..ed7f1c89ad 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -19,6 +19,7 @@ from distutils.version import LooseVersion logger = logging.getLogger(__name__) REQUIREMENTS = { + "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], "frozendict>=0.4": ["frozendict"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 86b37b9ddd..3b5e0a4fb9 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore): key_ids Args: server_name (str): The name of the server. - key_ids (list of str): List of key_ids to try and look up. + key_ids (iterable[str]): key_ids to try and look up. Returns: - (list of VerifyKey): The verification keys. + Deferred: resolves to dict[str, VerifyKey]: map from + key_id to verification key. """ keys = {} for key_id in key_ids: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 998de70d29..19595df422 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -189,7 +189,55 @@ class Cache(object): self.cache.clear() -class CacheDescriptor(object): +class _CacheDescriptorBase(object): + def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): + self.orig = orig + + if inlineCallbacks: + self.function_to_call = defer.inlineCallbacks(orig) + else: + self.function_to_call = orig + + arg_spec = inspect.getargspec(orig) + all_args = arg_spec.args + + if "cache_context" in all_args: + if not cache_context: + raise ValueError( + "Cannot have a 'cache_context' arg without setting" + " cache_context=True" + ) + elif cache_context: + raise ValueError( + "Cannot have cache_context=True without having an arg" + " named `cache_context`" + ) + + if num_args is None: + num_args = len(all_args) - 1 + if cache_context: + num_args -= 1 + + if len(all_args) < num_args + 1: + raise Exception( + "Not enough explicit positional arguments to key off for %r: " + "got %i args, but wanted %i. (@cached cannot key off *args or " + "**kwargs)" + % (orig.__name__, len(all_args), num_args) + ) + + self.num_args = num_args + self.arg_names = all_args[1:num_args + 1] + + if "cache_context" in self.arg_names: + raise Exception( + "cache_context arg cannot be included among the cache keys" + ) + + self.add_cache_context = cache_context + + +class CacheDescriptor(_CacheDescriptorBase): """ A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -217,52 +265,24 @@ class CacheDescriptor(object): r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate) defer.returnValue(r1 + r2) + Args: + num_args (int): number of positional arguments (excluding ``self`` and + ``cache_context``) to use as cache keys. Defaults to all named + args of the function. """ - def __init__(self, orig, max_entries=1000, num_args=1, tree=False, + def __init__(self, orig, max_entries=1000, num_args=None, tree=False, inlineCallbacks=False, cache_context=False, iterable=False): - max_entries = int(max_entries * CACHE_SIZE_FACTOR) - self.orig = orig + super(CacheDescriptor, self).__init__( + orig, num_args=num_args, inlineCallbacks=inlineCallbacks, + cache_context=cache_context) - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig + max_entries = int(max_entries * CACHE_SIZE_FACTOR) self.max_entries = max_entries - self.num_args = num_args self.tree = tree - self.iterable = iterable - all_args = inspect.getargspec(orig) - self.arg_names = all_args.args[1:num_args + 1] - - if "cache_context" in all_args.args: - if not cache_context: - raise ValueError( - "Cannot have a 'cache_context' arg without setting" - " cache_context=True" - ) - try: - self.arg_names.remove("cache_context") - except ValueError: - pass - elif cache_context: - raise ValueError( - "Cannot have cache_context=True without having an arg" - " named `cache_context`" - ) - - self.add_cache_context = cache_context - - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwargs)" - % (orig.__name__,) - ) - def __get__(self, obj, objtype=None): cache = Cache( name=self.orig.__name__, @@ -338,48 +358,36 @@ class CacheDescriptor(object): return wrapped -class CacheListDescriptor(object): +class CacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cached_method_name, list_name, num_args=1, + def __init__(self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False): """ Args: orig (function) - method_name (str); The name of the chached method. + cached_method_name (str): The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list - num_args (int) + num_args (int): number of positional arguments (excluding ``self``, + but including list_name) to use as cache keys. Defaults to all + named args of the function. inlineCallbacks (bool): Whether orig is a generator that should be wrapped by defer.inlineCallbacks """ - self.orig = orig + super(CacheListDescriptor, self).__init__( + orig, num_args=num_args, inlineCallbacks=inlineCallbacks) - if inlineCallbacks: - self.function_to_call = defer.inlineCallbacks(orig) - else: - self.function_to_call = orig - - self.num_args = num_args self.list_name = list_name - self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cached_method_name = cached_method_name self.sentinel = object() - if len(self.arg_names) < self.num_args: - raise Exception( - "Not enough explicit positional arguments to key off of for %r." - " (@cached cannot key off of *args or **kwars)" - % (orig.__name__,) - ) - if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." @@ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, +def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False): return lambda orig: CacheDescriptor( orig, @@ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False, - iterable=False): +def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, + cache_context=False, iterable=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex ) -def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False) cache (Cache): The underlying cache to use. list_name (str): The name of the argument that is the list to use to do batch lookups in the cache. - num_args (int): Number of arguments to use as the key in the cache. + num_args (int): Number of arguments to use as the key in the cache + (including list_name). Defaults to all named parameters. inlineCallbacks (bool): Should the function be wrapped in an `defer.inlineCallbacks`? diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 153ef001ad..b68e8c4e9f 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import synapse.util.logcontext from twisted.internet import defer from synapse.api.errors import CodeMessageException @@ -173,4 +173,5 @@ class RetryDestinationLimiter(object): "Failed to store set_destination_retry_timings", ) - store_retry_timings() + # we deliberately do this in the background. + synapse.util.logcontext.preserve_fn(store_retry_timings)() diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 50e8607c14..dcceca7f3e 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -23,6 +23,9 @@ from tests.utils import ( from synapse.api.filtering import Filter from synapse.events import FrozenEvent +from synapse.api.errors import SynapseError + +import jsonschema user_localpart = "test_user" @@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase): self.datastore = hs.get_datastore() + def test_errors_on_invalid_filters(self): + invalid_filters = [ + {"boom": {}}, + {"account_data": "Hello World"}, + {"event_fields": ["\\foo"]}, + {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, + {"event_format": "other"}, + {"room": {"not_rooms": ["#foo:pik-test"]}}, + {"presence": {"senders": ["@bar;pik.test.com"]}} + ] + for filter in invalid_filters: + with self.assertRaises(SynapseError) as check_filter_error: + self.filtering.check_valid_filter(filter) + self.assertIsInstance(check_filter_error.exception, SynapseError) + + def test_valid_filters(self): + valid_filters = [ + { + "room": { + "timeline": {"limit": 20}, + "state": {"not_types": ["m.room.member"]}, + "ephemeral": {"limit": 0, "not_types": ["*"]}, + "include_leave": False, + "rooms": ["!dee:pik-test"], + "not_rooms": ["!gee:pik-test"], + "account_data": {"limit": 0, "types": ["*"]} + } + }, + { + "room": { + "state": { + "types": ["m.room.*"], + "not_rooms": ["!726s6s6q:example.com"] + }, + "timeline": { + "limit": 10, + "types": ["m.room.message"], + "not_rooms": ["!726s6s6q:example.com"], + "not_senders": ["@spam:example.com"] + }, + "ephemeral": { + "types": ["m.receipt", "m.typing"], + "not_rooms": ["!726s6s6q:example.com"], + "not_senders": ["@spam:example.com"] + } + }, + "presence": { + "types": ["m.presence"], + "not_senders": ["@alice:example.com"] + }, + "event_format": "client", + "event_fields": ["type", "content", "sender"] + } + ] + for filter in valid_filters: + try: + self.filtering.check_valid_filter(filter) + except jsonschema.ValidationError as e: + self.fail(e) + + def test_limits_are_applied(self): + # TODO + pass + def test_definition_types_works_with_literals(self): definition = { "types": ["m.room.message", "org.matrix.foo.bar"] diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 3d27d03cbf..76b833e119 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.TestCase): USER_ID = "@apple:test" - EXAMPLE_FILTER = {"type": ["m.*"]} - EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}' + EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} + EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}' TO_REGISTER = [filter] @defer.inlineCallbacks diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py new file mode 100644 index 0000000000..0be790d8f8 --- /dev/null +++ b/tests/storage/test_keys.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import signedjson.key +from twisted.internet import defer + +import tests.unittest +import tests.utils + + +class KeyStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(KeyStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.keys.KeyStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_get_server_verify_keys(self): + key1 = signedjson.key.decode_verify_key_base64( + "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" + ) + key2 = signedjson.key.decode_verify_key_base64( + "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" + ) + yield self.store.store_server_verify_key( + "server1", "from_server", 0, key1 + ) + yield self.store.store_server_verify_key( + "server1", "from_server", 0, key2 + ) + + res = yield self.store.get_server_verify_keys( + "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]) + + self.assertEqual(len(res.keys()), 2) + self.assertEqual(res["ed25519:key1"].version, "key1") + self.assertEqual(res["ed25519:key2"].version, "key2") diff --git a/tests/util/caches/__init__.py b/tests/util/caches/__init__.py new file mode 100644 index 0000000000..451dae3b6c --- /dev/null +++ b/tests/util/caches/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py new file mode 100644 index 0000000000..419281054d --- /dev/null +++ b/tests/util/caches/test_descriptors.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import mock +from twisted.internet import defer +from synapse.util.caches import descriptors +from tests import unittest + + +class DescriptorTestCase(unittest.TestCase): + @defer.inlineCallbacks + def test_cache(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + + obj.mock.return_value = 'fish' + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = 'chips' + r = yield obj.fn(1, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(1, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_not_called() + + @defer.inlineCallbacks + def test_cache_num_args(self): + """Only the first num_args arguments should matter to the cache""" + + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached(num_args=1) + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + obj.mock.return_value = 'fish' + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = 'chips' + r = yield obj.fn(2, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_called_once_with(2, 3) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. + r = yield obj.fn(1, 4) + self.assertEqual(r, 'fish') + r = yield obj.fn(2, 5) + self.assertEqual(r, 'chips') + obj.mock.assert_not_called() |