summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/api/filtering.py109
-rw-r--r--synapse/rest/client/v2_alpha/filter.py2
-rw-r--r--synapse/storage/filtering.py4
-rw-r--r--tests/api/test_filtering.py24
4 files changed, 128 insertions, 11 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 20b6951d47..6c7a73b6d5 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, RoomID
 
 
 class Filtering(object):
@@ -25,10 +26,110 @@ class Filtering(object):
     def get_user_filter(self, user_localpart, filter_id):
         return self.store.get_user_filter(user_localpart, filter_id)
 
-    def add_user_filter(self, user_localpart, definition):
-        # TODO(paul): implement sanity checking of the definition
-        return self.store.add_user_filter(user_localpart, definition)
+    def add_user_filter(self, user_localpart, user_filter):
+        self._check_valid_filter(user_filter)
+        return self.store.add_user_filter(user_localpart, user_filter)
 
     # TODO(paul): surely we should probably add a delete_user_filter or
     #   replace_user_filter at some point? There's no REST API specified for
     #   them however
+
+    def _check_valid_filter(self, user_filter):
+        """Check if the provided filter is valid.
+
+        This inspects all definitions contained within the filter.
+
+        Args:
+            user_filter(dict): The filter
+        Raises:
+            SynapseError: If the filter is not valid.
+        """
+        # NB: Filters are the complete json blobs. "Definitions" are an
+        # individual top-level key e.g. public_user_data. Filters are made of
+        # many definitions.
+
+        top_level_definitions = [
+            "public_user_data", "private_user_data", "server_data"
+        ]
+
+        room_level_definitions = [
+            "state", "events", "ephemeral"
+        ]
+
+        for key in top_level_definitions:
+            if key in user_filter:
+                self._check_definition(user_filter[key])
+
+        if "room" in user_filter:
+            for key in room_level_definitions:
+                if key in user_filter["room"]:
+                    self._check_definition(user_filter["room"][key])
+
+
+    def _check_definition(self, definition):
+        """Check if the provided definition is valid.
+
+        This inspects not only the types but also the values to make sure they
+        make sense.
+
+        Args:
+            definition(dict): The filter definition
+        Raises:
+            SynapseError: If there was a problem with this definition.
+        """
+        # NB: Filters are the complete json blobs. "Definitions" are an
+        # individual top-level key e.g. public_user_data. Filters are made of
+        # many definitions.
+        if type(definition) != dict:
+            raise SynapseError(
+                400, "Expected JSON object, not %s" % (definition,)
+            )
+
+        # check rooms are valid room IDs
+        room_id_keys = ["rooms", "not_rooms"]
+        for key in room_id_keys:
+            if key in definition:
+                if type(definition[key]) != list:
+                    raise SynapseError(400, "Expected %s to be a list." % key)
+                for room_id in definition[key]:
+                    RoomID.from_string(room_id)
+
+        # check senders are valid user IDs
+        user_id_keys = ["senders", "not_senders"]
+        for key in user_id_keys:
+            if key in definition:
+                if type(definition[key]) != list:
+                    raise SynapseError(400, "Expected %s to be a list." % key)
+                for user_id in definition[key]:
+                    UserID.from_string(user_id)
+
+        # TODO: We don't limit event type values but we probably should...
+        # check types are valid event types
+        event_keys = ["types", "not_types"]
+        for key in event_keys:
+            if key in definition:
+                if type(definition[key]) != list:
+                    raise SynapseError(400, "Expected %s to be a list." % key)
+                for event_type in definition[key]:
+                    if not isinstance(event_type, basestring):
+                        raise SynapseError(400, "Event type should be a string")
+
+        try:
+            event_format = definition["format"]
+            if event_format not in ["federation", "events"]:
+                raise SynapseError(400, "Invalid format: %s" % (event_format,))
+        except KeyError:
+            pass  # format is optional
+
+        try:
+            event_select_list = definition["select"]
+            for select_key in event_select_list:
+                if select_key not in ["event_id", "origin_server_ts",
+                                      "thread_id", "content", "content.body"]:
+                    raise SynapseError(400, "Bad select: %s" % (select_key,))
+        except KeyError:
+            pass  # select is optional
+
+        if ("bundle_updates" in definition and
+                type(definition["bundle_updates"]) != bool):
+            raise SynapseError(400, "Bad bundle_updates: expected bool.")
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 09e44e8ae0..81a3e95155 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -93,7 +93,7 @@ class CreateFilterRestServlet(RestServlet):
 
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=target_user.localpart,
-            definition=content,
+            user_filter=content,
         )
 
         defer.returnValue((200, {"filter_id": str(filter_id)}))
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index e98eaf8032..bab68a9eef 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -39,8 +39,8 @@ class FilteringStore(SQLBaseStore):
 
         defer.returnValue(json.loads(def_json))
 
-    def add_user_filter(self, user_localpart, definition):
-        def_json = json.dumps(definition)
+    def add_user_filter(self, user_localpart, user_filter):
+        def_json = json.dumps(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then
         # INSERT a new one
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 149948374d..188fbfb91e 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -57,13 +57,21 @@ class FilteringTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_add_filter(self):
+        user_filter = {
+            "room": {
+                "state": {
+                    "types": ["m.*"]
+                }
+            }
+        }
+
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=user_localpart,
-            definition={"type": ["m.*"]},
+            user_filter=user_filter,
         )
 
         self.assertEquals(filter_id, 0)
-        self.assertEquals({"type": ["m.*"]},
+        self.assertEquals(user_filter,
             (yield self.datastore.get_user_filter(
                 user_localpart=user_localpart,
                 filter_id=0,
@@ -72,9 +80,17 @@ class FilteringTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_filter(self):
+        user_filter = {
+            "room": {
+                "state": {
+                    "types": ["m.*"]
+                }
+            }
+        }
+
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            definition={"type": ["m.*"]},
+            user_filter=user_filter,
         )
 
         filter = yield self.filtering.get_user_filter(
@@ -82,4 +98,4 @@ class FilteringTestCase(unittest.TestCase):
             filter_id=filter_id,
         )
 
-        self.assertEquals(filter, {"type": ["m.*"]})
+        self.assertEquals(filter, user_filter)