summary refs log tree commit diff
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-01-29 17:41:48 +0000
committerMark Haines <mark.haines@matrix.org>2015-01-29 17:45:07 +0000
commit93ed31dda2e23742c3d7f3eee6ac6839682f0ce9 (patch)
tree185ebc6da02679b3bd3f7f9dd7c405624f925052
parentMerge branch 'develop' into client_v2_filter (diff)
downloadsynapse-93ed31dda2e23742c3d7f3eee6ac6839682f0ce9.tar.xz
Create a separate filter object to do the actual filtering, so that we can
split the storage and management of filters from the actual filter code
and don't have to load a filter from the db each time we filter an event
-rw-r--r--synapse/api/filtering.py220
-rw-r--r--synapse/rest/client/v2_alpha/filter.py2
-rw-r--r--tests/api/test_filtering.py108
3 files changed, 166 insertions, 164 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index e16c0e559f..b7e5d3222f 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -25,127 +25,25 @@ class Filtering(object):
         self.store = hs.get_datastore()
 
     def get_user_filter(self, user_localpart, filter_id):
-        return self.store.get_user_filter(user_localpart, filter_id)
+        result = self.store.get_user_filter(user_localpart, filter_id)
+        result.addCallback(Filter)
+        return result
 
     def add_user_filter(self, user_localpart, user_filter):
         self._check_valid_filter(user_filter)
         return self.store.add_user_filter(user_localpart, user_filter)
 
-    def filter_public_user_data(self, events, user, filter_id):
-        return self._filter_on_key(
-            events, user, filter_id, ["public_user_data"]
-        )
-
-    def filter_private_user_data(self, events, user, filter_id):
-        return self._filter_on_key(
-            events, user, filter_id, ["private_user_data"]
-        )
-
-    def filter_room_state(self, events, user, filter_id):
-        return self._filter_on_key(
-            events, user, filter_id, ["room", "state"]
-        )
-
-    def filter_room_events(self, events, user, filter_id):
-        return self._filter_on_key(
-            events, user, filter_id, ["room", "events"]
-        )
-
-    def filter_room_ephemeral(self, events, user, filter_id):
-        return self._filter_on_key(
-            events, user, filter_id, ["room", "ephemeral"]
-        )
-
     # TODO(paul): surely we should probably add a delete_user_filter or
     #   replace_user_filter at some point? There's no REST API specified for
     #   them however
 
-    @defer.inlineCallbacks
-    def _filter_on_key(self, events, user, filter_id, keys):
-        filter_json = yield self.get_user_filter(user.localpart, filter_id)
-        if not filter_json:
-            defer.returnValue(events)
-
-        try:
-            # extract the right definition from the filter
-            definition = filter_json
-            for key in keys:
-                definition = definition[key]
-            defer.returnValue(self._filter_with_definition(events, definition))
-        except KeyError:
-            # return all events if definition isn't specified.
-            defer.returnValue(events)
-
-    def _filter_with_definition(self, events, definition):
-        return [e for e in events if self._passes_definition(definition, e)]
-
-    def _passes_definition(self, definition, event):
-        """Check if the event passes through the given definition.
-
-        Args:
-            definition(dict): The definition to check against.
-            event(Event): The event to check.
-        Returns:
-            True if the event passes through the filter.
-        """
-        # Algorithm notes:
-        # For each key in the definition, check the event meets the criteria:
-        #   * For types: Literal match or prefix match (if ends with wildcard)
-        #   * For senders/rooms: Literal match only
-        #   * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
-        #     and 'not_types' then it is treated as only being in 'not_types')
-
-        # room checks
-        if hasattr(event, "room_id"):
-            room_id = event.room_id
-            allow_rooms = definition.get("rooms", None)
-            reject_rooms = definition.get("not_rooms", None)
-            if reject_rooms and room_id in reject_rooms:
-                return False
-            if allow_rooms and room_id not in allow_rooms:
-                return False
-
-        # sender checks
-        if hasattr(event, "sender"):
-            # Should we be including event.state_key for some event types?
-            sender = event.sender
-            allow_senders = definition.get("senders", None)
-            reject_senders = definition.get("not_senders", None)
-            if reject_senders and sender in reject_senders:
-                return False
-            if allow_senders and sender not in allow_senders:
-                return False
-
-        # type checks
-        if "not_types" in definition:
-            for def_type in definition["not_types"]:
-                if self._event_matches_type(event, def_type):
-                    return False
-        if "types" in definition:
-            included = False
-            for def_type in definition["types"]:
-                if self._event_matches_type(event, def_type):
-                    included = True
-                    break
-            if not included:
-                return False
-
-        return True
-
-    def _event_matches_type(self, event, def_type):
-        if def_type.endswith("*"):
-            type_prefix = def_type[:-1]
-            return event.type.startswith(type_prefix)
-        else:
-            return event.type == def_type
-
-    def _check_valid_filter(self, user_filter):
+    def _check_valid_filter(self, user_filter_json):
         """Check if the provided filter is valid.
 
         This inspects all definitions contained within the filter.
 
         Args:
-            user_filter(dict): The filter
+            user_filter_json(dict): The filter
         Raises:
             SynapseError: If the filter is not valid.
         """
@@ -162,13 +60,13 @@ class Filtering(object):
         ]
 
         for key in top_level_definitions:
-            if key in user_filter:
-                self._check_definition(user_filter[key])
+            if key in user_filter_json:
+                self._check_definition(user_filter_json[key])
 
-        if "room" in user_filter:
+        if "room" in user_filter_json:
             for key in room_level_definitions:
-                if key in user_filter["room"]:
-                    self._check_definition(user_filter["room"][key])
+                if key in user_filter_json["room"]:
+                    self._check_definition(user_filter_json["room"][key])
 
     def _check_definition(self, definition):
         """Check if the provided definition is valid.
@@ -237,3 +135,101 @@ class Filtering(object):
         if ("bundle_updates" in definition and
                 type(definition["bundle_updates"]) != bool):
             raise SynapseError(400, "Bad bundle_updates: expected bool.")
+
+
+class Filter(object):
+    def __init__(self, filter_json):
+        self.filter_json = filter_json
+
+    def filter_public_user_data(self, events):
+        return self._filter_on_key(events, ["public_user_data"])
+
+    def filter_private_user_data(self, events):
+        return self._filter_on_key(events, ["private_user_data"])
+
+    def filter_room_state(self, events):
+        return self._filter_on_key(events, ["room", "state"])
+
+    def filter_room_events(self, events):
+        return self._filter_on_key(events, ["room", "events"])
+
+    def filter_room_ephemeral(self, events):
+        return self._filter_on_key(events, ["room", "ephemeral"])
+
+    def _filter_on_key(self, events, keys):
+        filter_json = self.filter_json
+        if not filter_json:
+            return events
+
+        try:
+            # extract the right definition from the filter
+            definition = filter_json
+            for key in keys:
+                definition = definition[key]
+            return self._filter_with_definition(events, definition)
+        except KeyError:
+            # return all events if definition isn't specified.
+            return events
+
+    def _filter_with_definition(self, events, definition):
+        return [e for e in events if self._passes_definition(definition, e)]
+
+    def _passes_definition(self, definition, event):
+        """Check if the event passes through the given definition.
+
+        Args:
+            definition(dict): The definition to check against.
+            event(Event): The event to check.
+        Returns:
+            True if the event passes through the filter.
+        """
+        # Algorithm notes:
+        # For each key in the definition, check the event meets the criteria:
+        #   * For types: Literal match or prefix match (if ends with wildcard)
+        #   * For senders/rooms: Literal match only
+        #   * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
+        #     and 'not_types' then it is treated as only being in 'not_types')
+
+        # room checks
+        if hasattr(event, "room_id"):
+            room_id = event.room_id
+            allow_rooms = definition.get("rooms", None)
+            reject_rooms = definition.get("not_rooms", None)
+            if reject_rooms and room_id in reject_rooms:
+                return False
+            if allow_rooms and room_id not in allow_rooms:
+                return False
+
+        # sender checks
+        if hasattr(event, "sender"):
+            # Should we be including event.state_key for some event types?
+            sender = event.sender
+            allow_senders = definition.get("senders", None)
+            reject_senders = definition.get("not_senders", None)
+            if reject_senders and sender in reject_senders:
+                return False
+            if allow_senders and sender not in allow_senders:
+                return False
+
+        # type checks
+        if "not_types" in definition:
+            for def_type in definition["not_types"]:
+                if self._event_matches_type(event, def_type):
+                    return False
+        if "types" in definition:
+            included = False
+            for def_type in definition["types"]:
+                if self._event_matches_type(event, def_type):
+                    included = True
+                    break
+            if not included:
+                return False
+
+        return True
+
+    def _event_matches_type(self, event, def_type):
+        if def_type.endswith("*"):
+            type_prefix = def_type[:-1]
+            return event.type.startswith(type_prefix)
+        else:
+            return event.type == def_type
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index cee06ccaca..6ddc495d23 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -59,7 +59,7 @@ class GetFilterRestServlet(RestServlet):
                 filter_id=filter_id,
             )
 
-            defer.returnValue((200, filter))
+            defer.returnValue((200, filter.filter_json))
         except KeyError:
             raise SynapseError(400, "No such filter")
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index aa93616a9f..babf4c37f1 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -25,6 +25,7 @@ from tests.utils import (
 
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.api.filtering import Filter
 
 user_localpart = "test_user"
 MockEvent = namedtuple("MockEvent", "sender type room_id")
@@ -53,6 +54,7 @@ class FilteringTestCase(unittest.TestCase):
         )
 
         self.filtering = hs.get_filtering()
+        self.filter = Filter({})
 
         self.datastore = hs.get_datastore()
 
@@ -66,7 +68,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_types_works_with_wildcards(self):
@@ -79,7 +81,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_types_works_with_unknowns(self):
@@ -92,7 +94,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_works_with_literals(self):
@@ -105,7 +107,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_works_with_wildcards(self):
@@ -118,7 +120,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_works_with_unknowns(self):
@@ -131,7 +133,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_takes_priority_over_types(self):
@@ -145,7 +147,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_senders_works_with_literals(self):
@@ -158,7 +160,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_senders_works_with_unknowns(self):
@@ -171,7 +173,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_senders_works_with_literals(self):
@@ -184,7 +186,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_senders_works_with_unknowns(self):
@@ -197,7 +199,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_senders_takes_priority_over_senders(self):
@@ -211,7 +213,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_rooms_works_with_literals(self):
@@ -224,7 +226,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_rooms_works_with_unknowns(self):
@@ -237,7 +239,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_rooms_works_with_literals(self):
@@ -250,7 +252,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_rooms_works_with_unknowns(self):
@@ -263,7 +265,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_rooms_takes_priority_over_rooms(self):
@@ -277,7 +279,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event(self):
@@ -295,7 +297,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event_bad_sender(self):
@@ -313,7 +315,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event_bad_room(self):
@@ -331,7 +333,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!piggyshouse:muppets"  # nope
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event_bad_type(self):
@@ -349,12 +351,12 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     @defer.inlineCallbacks
     def test_filter_public_user_data_match(self):
-        user_filter = {
+        user_filter_json = {
             "public_user_data": {
                 "types": ["m.*"]
             }
@@ -362,7 +364,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -371,16 +373,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_public_user_data(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_public_user_data(events=events)
         self.assertEquals(events, results)
 
     @defer.inlineCallbacks
     def test_filter_public_user_data_no_match(self):
-        user_filter = {
+        user_filter_json = {
             "public_user_data": {
                 "types": ["m.*"]
             }
@@ -388,7 +391,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -397,16 +400,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_public_user_data(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_public_user_data(events=events)
         self.assertEquals([], results)
 
     @defer.inlineCallbacks
     def test_filter_room_state_match(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -416,7 +420,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -425,16 +429,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_room_state(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_room_state(events=events)
         self.assertEquals(events, results)
 
     @defer.inlineCallbacks
     def test_filter_room_state_no_match(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -444,7 +449,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -453,16 +458,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_room_state(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_room_state(events)
         self.assertEquals([], results)
 
     @defer.inlineCallbacks
     def test_add_filter(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -472,11 +478,11 @@ class FilteringTestCase(unittest.TestCase):
 
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
 
         self.assertEquals(filter_id, 0)
-        self.assertEquals(user_filter,
+        self.assertEquals(user_filter_json,
             (yield self.datastore.get_user_filter(
                 user_localpart=user_localpart,
                 filter_id=0,
@@ -485,7 +491,7 @@ class FilteringTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_filter(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -495,7 +501,7 @@ class FilteringTestCase(unittest.TestCase):
 
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
 
         filter = yield self.filtering.get_user_filter(
@@ -503,4 +509,4 @@ class FilteringTestCase(unittest.TestCase):
             filter_id=filter_id,
         )
 
-        self.assertEquals(filter, user_filter)
+        self.assertEquals(filter.filter_json, user_filter_json)