summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/api/filtering.py261
-rw-r--r--synapse/crypto/keyring.py23
-rw-r--r--synapse/handlers/e2e_keys.py2
-rw-r--r--synapse/python_dependencies.py1
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/util/__init__.py10
-rw-r--r--synapse/util/caches/descriptors.py135
-rw-r--r--synapse/util/retryutils.py5
9 files changed, 277 insertions, 170 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 03a215ab1b..9dbc7993df 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes
 from synapse.types import UserID
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util import logcontext
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -209,8 +209,7 @@ class Auth(object):
                 default=[""]
             )[0]
             if user and access_token and ip_addr:
-                preserve_context_over_fn(
-                    self.store.insert_client_ip,
+                logcontext.preserve_fn(self.store.insert_client_ip)(
                     user=user,
                     access_token=access_token,
                     ip=ip_addr,
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 47f0cf0fa9..83206348e5 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,10 +15,172 @@
 from synapse.api.errors import SynapseError
 from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, RoomID
-
 from twisted.internet import defer
 
 import ujson as json
+import jsonschema
+from jsonschema import FormatChecker
+
+FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        # TODO: We don't limit event type values but we probably should...
+        # check types are valid event types
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        }
+    }
+}
+
+ROOM_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "ephemeral": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "include_leave": {
+            "type": "boolean"
+        },
+        "state": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "timeline": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+    }
+}
+
+ROOM_EVENT_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "contains_url": {
+            "type": "boolean"
+        }
+    }
+}
+
+USER_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_user_id"
+    }
+}
+
+ROOM_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_room_id"
+    }
+}
+
+USER_FILTER_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-04/schema#",
+    "description": "schema for a Sync filter",
+    "type": "object",
+    "definitions": {
+        "room_id_array": ROOM_ID_ARRAY_SCHEMA,
+        "user_id_array": USER_ID_ARRAY_SCHEMA,
+        "filter": FILTER_SCHEMA,
+        "room_filter": ROOM_FILTER_SCHEMA,
+        "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+    },
+    "properties": {
+        "presence": {
+            "$ref": "#/definitions/filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/filter"
+        },
+        "room": {
+            "$ref": "#/definitions/room_filter"
+        },
+        "event_format": {
+            "type": "string",
+            "enum": ["client", "federation"]
+        },
+        "event_fields": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                # Don't allow '\\' in event field filters. This makes matching
+                # events a lot easier as we can then use a negative lookbehind
+                # assertion to split '\.' If we allowed \\ then it would
+                # incorrectly split '\\.' See synapse.events.utils.serialize_event
+                "pattern": "^((?!\\\).)*$"
+            }
+        }
+    },
+    "additionalProperties": False
+}
+
+
+@FormatChecker.cls_checks('matrix_room_id')
+def matrix_room_id_validator(room_id_str):
+    return RoomID.from_string(room_id_str)
+
+
+@FormatChecker.cls_checks('matrix_user_id')
+def matrix_user_id_validator(user_id_str):
+    return UserID.from_string(user_id_str)
 
 
 class Filtering(object):
@@ -53,98 +215,11 @@ class Filtering(object):
         # NB: Filters are the complete json blobs. "Definitions" are an
         # individual top-level key e.g. public_user_data. Filters are made of
         # many definitions.
-
-        top_level_definitions = [
-            "presence", "account_data"
-        ]
-
-        room_level_definitions = [
-            "state", "timeline", "ephemeral", "account_data"
-        ]
-
-        for key in top_level_definitions:
-            if key in user_filter_json:
-                self._check_definition(user_filter_json[key])
-
-        if "room" in user_filter_json:
-            self._check_definition_room_lists(user_filter_json["room"])
-            for key in room_level_definitions:
-                if key in user_filter_json["room"]:
-                    self._check_definition(user_filter_json["room"][key])
-
-        if "event_fields" in user_filter_json:
-            if type(user_filter_json["event_fields"]) != list:
-                raise SynapseError(400, "event_fields must be a list of strings")
-            for field in user_filter_json["event_fields"]:
-                if not isinstance(field, basestring):
-                    raise SynapseError(400, "Event field must be a string")
-                # Don't allow '\\' in event field filters. This makes matching
-                # events a lot easier as we can then use a negative lookbehind
-                # assertion to split '\.' If we allowed \\ then it would
-                # incorrectly split '\\.' See synapse.events.utils.serialize_event
-                if r'\\' in field:
-                    raise SynapseError(
-                        400, r'The escape character \ cannot itself be escaped'
-                    )
-
-    def _check_definition_room_lists(self, definition):
-        """Check that "rooms" and "not_rooms" are lists of room ids if they
-        are present
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # check rooms are valid room IDs
-        room_id_keys = ["rooms", "not_rooms"]
-        for key in room_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for room_id in definition[key]:
-                    RoomID.from_string(room_id)
-
-    def _check_definition(self, definition):
-        """Check if the provided definition is valid.
-
-        This inspects not only the types but also the values to make sure they
-        make sense.
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # NB: Filters are the complete json blobs. "Definitions" are an
-        # individual top-level key e.g. public_user_data. Filters are made of
-        # many definitions.
-        if type(definition) != dict:
-            raise SynapseError(
-                400, "Expected JSON object, not %s" % (definition,)
-            )
-
-        self._check_definition_room_lists(definition)
-
-        # check senders are valid user IDs
-        user_id_keys = ["senders", "not_senders"]
-        for key in user_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for user_id in definition[key]:
-                    UserID.from_string(user_id)
-
-        # TODO: We don't limit event type values but we probably should...
-        # check types are valid event types
-        event_keys = ["types", "not_types"]
-        for key in event_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for event_type in definition[key]:
-                    if not isinstance(event_type, basestring):
-                        raise SynapseError(400, "Event type should be a string")
+        try:
+            jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
+                                format_checker=FormatChecker())
+        except jsonschema.ValidationError as e:
+            raise SynapseError(400, e.message)
 
 
 class FilterCollection(object):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c4bc4f4d31..1bb27edc0f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -237,8 +237,14 @@ class Keyring(object):
             d.addBoth(rm, server_name)
 
     def get_server_verify_keys(self, verify_requests):
-        """Takes a dict of KeyGroups and tries to find at least one key for
-        each group.
+        """Tries to find at least one key for each verify request
+
+        For each verify_request, verify_request.deferred is called back with
+        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
+        with a SynapseError if none of the keys are found.
+
+        Args:
+            verify_requests (list[VerifyKeyRequest]): list of verify requests
         """
 
         # These are functions that produce keys given a list of key ids
@@ -251,8 +257,11 @@ class Keyring(object):
         @defer.inlineCallbacks
         def do_iterations():
             with Measure(self.clock, "get_server_verify_keys"):
+                # dict[str, dict[str, VerifyKey]]: results so far.
+                # map server_name -> key_id -> VerifyKey
                 merged_results = {}
 
+                # dict[str, set(str)]: keys to fetch for each server
                 missing_keys = {}
                 for verify_request in verify_requests:
                     missing_keys.setdefault(verify_request.server_name, set()).update(
@@ -314,6 +323,16 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
+        """
+
+        Args:
+            server_name_and_key_ids (list[(str, iterable[str])]):
+                list of (server_name, iterable[key_id]) tuples to fetch keys for
+
+        Returns:
+            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+                server_name -> key_id -> VerifyKey
+        """
         res = yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store.get_server_verify_keys)(
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index a33135de67..c2b38d72a9 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -308,7 +308,7 @@ class E2eKeysHandler(object):
         # old access_token without an associated device_id. Either way, we
         # need to double-check the device is registered to avoid ending up with
         # keys without a corresponding device.
-        self.device_handler.check_device_registered(user_id, device_id)
+        yield self.device_handler.check_device_registered(user_id, device_id)
 
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
 
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c4777b2a2b..ed7f1c89ad 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -19,6 +19,7 @@ from distutils.version import LooseVersion
 logger = logging.getLogger(__name__)
 
 REQUIREMENTS = {
+    "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
     "frozendict>=0.4": ["frozendict"],
     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 86b37b9ddd..3b5e0a4fb9 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
         key_ids
         Args:
             server_name (str): The name of the server.
-            key_ids (list of str): List of key_ids to try and look up.
+            key_ids (iterable[str]): key_ids to try and look up.
         Returns:
-            (list of VerifyKey): The verification keys.
+            Deferred: resolves to dict[str, VerifyKey]: map from
+               key_id to verification key.
         """
         keys = {}
         for key_id in key_ids:
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 30fc480108..98a5a26ac5 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class DeferredTimedOutError(SynapseError):
     def __init__(self):
-        super(SynapseError).__init__(504, "Timed out")
+        super(SynapseError, self).__init__(504, "Timed out")
 
 
 def unwrapFirstError(failure):
@@ -93,8 +93,10 @@ class Clock(object):
         ret_deferred = defer.Deferred()
 
         def timed_out_fn():
+            e = DeferredTimedOutError()
+
             try:
-                ret_deferred.errback(DeferredTimedOutError())
+                ret_deferred.errback(e)
             except:
                 pass
 
@@ -114,7 +116,7 @@ class Clock(object):
 
         ret_deferred.addBoth(cancel)
 
-        def sucess(res):
+        def success(res):
             try:
                 ret_deferred.callback(res)
             except:
@@ -128,7 +130,7 @@ class Clock(object):
             except:
                 pass
 
-        given_deferred.addCallbacks(callback=sucess, errback=err)
+        given_deferred.addCallbacks(callback=success, errback=err)
 
         timer = self.call_later(time_out, timed_out_fn)
 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d29..19595df422 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -189,7 +189,55 @@ class Cache(object):
         self.cache.clear()
 
 
-class CacheDescriptor(object):
+class _CacheDescriptorBase(object):
+    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        arg_spec = inspect.getargspec(orig)
+        all_args = arg_spec.args
+
+        if "cache_context" in all_args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
+
+        if num_args is None:
+            num_args = len(all_args) - 1
+            if cache_context:
+                num_args -= 1
+
+        if len(all_args) < num_args + 1:
+            raise Exception(
+                "Not enough explicit positional arguments to key off for %r: "
+                "got %i args, but wanted %i. (@cached cannot key off *args or "
+                "**kwargs)"
+                % (orig.__name__, len(all_args), num_args)
+            )
+
+        self.num_args = num_args
+        self.arg_names = all_args[1:num_args + 1]
+
+        if "cache_context" in self.arg_names:
+            raise Exception(
+                "cache_context arg cannot be included among the cache keys"
+            )
+
+        self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +265,24 @@ class CacheDescriptor(object):
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)
 
+    Args:
+        num_args (int): number of positional arguments (excluding ``self`` and
+            ``cache_context``) to use as cache keys. Defaults to all named
+            args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
                  inlineCallbacks=False, cache_context=False, iterable=False):
-        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
-        self.orig = orig
+        super(CacheDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.max_entries = max_entries
-        self.num_args = num_args
         self.tree = tree
-
         self.iterable = iterable
 
-        all_args = inspect.getargspec(orig)
-        self.arg_names = all_args.args[1:num_args + 1]
-
-        if "cache_context" in all_args.args:
-            if not cache_context:
-                raise ValueError(
-                    "Cannot have a 'cache_context' arg without setting"
-                    " cache_context=True"
-                )
-            try:
-                self.arg_names.remove("cache_context")
-            except ValueError:
-                pass
-        elif cache_context:
-            raise ValueError(
-                "Cannot have cache_context=True without having an arg"
-                " named `cache_context`"
-            )
-
-        self.add_cache_context = cache_context
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwargs)"
-                % (orig.__name__,)
-            )
-
     def __get__(self, obj, objtype=None):
         cache = Cache(
             name=self.orig.__name__,
@@ -338,48 +358,36 @@ class CacheDescriptor(object):
         return wrapped
 
 
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
     the list of missing keys to the wrapped fucntion.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+    def __init__(self, orig, cached_method_name, list_name, num_args=None,
                  inlineCallbacks=False):
         """
         Args:
             orig (function)
-            method_name (str); The name of the chached method.
+            cached_method_name (str): The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int)
+            num_args (int): number of positional arguments (excluding ``self``,
+                but including list_name) to use as cache keys. Defaults to all
+                named args of the function.
             inlineCallbacks (bool): Whether orig is a generator that should
                 be wrapped by defer.inlineCallbacks
         """
-        self.orig = orig
+        super(CacheListDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
-
         self.cached_method_name = cached_method_name
 
         self.sentinel = object()
 
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
@@ -487,7 +495,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
            iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
@@ -499,8 +507,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
-                          iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+                          cache_context=False, iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -512,7 +520,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
     )
 
 
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +533,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
         cache (Cache): The underlying cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
-        num_args (int): Number of arguments to use as the key in the cache.
+        num_args (int): Number of arguments to use as the key in the cache
+            (including list_name). Defaults to all named parameters.
         inlineCallbacks (bool): Should the function be wrapped in an
             `defer.inlineCallbacks`?
 
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 7f5299bd32..4fa9d1a03c 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import synapse.util.logcontext
 from twisted.internet import defer
 
 from synapse.api.errors import CodeMessageException
@@ -194,4 +194,5 @@ class RetryDestinationLimiter(object):
                     "Failed to store set_destination_retry_timings",
                 )
 
-        store_retry_timings()
+        # we deliberately do this in the background.
+        synapse.util.logcontext.preserve_fn(store_retry_timings)()