summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/filtering.py174
-rw-r--r--synapse/handlers/search.py11
-rw-r--r--synapse/rest/client/v2_alpha/sync.py4
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/delta/25/fts.py (renamed from synapse/storage/schema/delta/24/fts.py)5
-rw-r--r--tests/api/test_filtering.py57
6 files changed, 123 insertions, 130 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index e79e91e7eb..ab14b47281 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -24,7 +24,7 @@ class Filtering(object):
 
     def get_user_filter(self, user_localpart, filter_id):
         result = self.store.get_user_filter(user_localpart, filter_id)
-        result.addCallback(Filter)
+        result.addCallback(FilterCollection)
         return result
 
     def add_user_filter(self, user_localpart, user_filter):
@@ -131,125 +131,107 @@ class Filtering(object):
             raise SynapseError(400, "Bad bundle_updates: expected bool.")
 
 
-class Filter(object):
+class FilterCollection(object):
     def __init__(self, filter_json):
         self.filter_json = filter_json
 
+        self.room_timeline_filter = Filter(
+            self.filter_json.get("room", {}).get("timeline", {})
+        )
+
+        self.room_state_filter = Filter(
+            self.filter_json.get("room", {}).get("state", {})
+        )
+
+        self.room_ephemeral_filter = Filter(
+            self.filter_json.get("room", {}).get("ephemeral", {})
+        )
+
+        self.presence_filter = Filter(
+            self.filter_json.get("presence", {})
+        )
+
     def timeline_limit(self):
-        return self.filter_json.get("room", {}).get("timeline", {}).get("limit", 10)
+        return self.room_timeline_filter.limit()
 
     def presence_limit(self):
-        return self.filter_json.get("presence", {}).get("limit", 10)
+        return self.presence_filter.limit()
 
     def ephemeral_limit(self):
-        return self.filter_json.get("room", {}).get("ephemeral", {}).get("limit", 10)
+        return self.room_ephemeral_filter.limit()
 
     def filter_presence(self, events):
-        return self._filter_on_key(events, ["presence"])
+        return self.presence_filter.filter(events)
 
     def filter_room_state(self, events):
-        return self._filter_on_key(events, ["room", "state"])
+        return self.room_state_filter.filter(events)
 
     def filter_room_timeline(self, events):
-        return self._filter_on_key(events, ["room", "timeline"])
+        return self.room_timeline_filter.filter(events)
 
     def filter_room_ephemeral(self, events):
-        return self._filter_on_key(events, ["room", "ephemeral"])
-
-    def _filter_on_key(self, events, keys):
-        filter_json = self.filter_json
-        if not filter_json:
-            return events
-
-        try:
-            # extract the right definition from the filter
-            definition = filter_json
-            for key in keys:
-                definition = definition[key]
-            return self._filter_with_definition(events, definition)
-        except KeyError:
-            # return all events if definition isn't specified.
-            return events
-
-    def _filter_with_definition(self, events, definition):
-        return [e for e in events if self._passes_definition(definition, e)]
-
-    def _passes_definition(self, definition, event):
-        """Check if the event passes the filter definition
-        Args:
-            definition(dict): The filter definition to check against
-            event(dict or Event): The event to check
+        return self.room_ephemeral_filter.filter(events)
+
+
+class Filter(object):
+    def __init__(self, filter_json):
+        self.filter_json = filter_json
+
+    def check(self, event):
+        """Checks whether the filter matches the given event.
+
         Returns:
-            True if the event passes the filter in the definition
+            bool: True if the event matches
         """
-        if type(event) is dict:
-            room_id = event.get("room_id")
-            sender = event.get("sender")
-            event_type = event["type"]
-        else:
-            room_id = getattr(event, "room_id", None)
-            sender = getattr(event, "sender", None)
-            event_type = event.type
-        return self._event_passes_definition(
-            definition, room_id, sender, event_type
-        )
+        literal_keys = {
+            "rooms": lambda v: event.room_id == v,
+            "senders": lambda v: event.sender == v,
+            "types": lambda v: _matches_wildcard(event.type, v)
+        }
+
+        for name, match_func in literal_keys.items():
+            not_name = "not_%s" % (name,)
+            disallowed_values = self.filter_json.get(not_name, [])
+            if any(map(match_func, disallowed_values)):
+                return False
 
-    def _event_passes_definition(self, definition, room_id, sender,
-                                 event_type):
-        """Check if the event passes through the given definition.
+            allowed_values = self.filter_json.get(name, None)
+            if allowed_values is not None:
+                if not any(map(match_func, allowed_values)):
+                    return False
+
+        return True
+
+    def filter_rooms(self, room_ids):
+        """Apply the 'rooms' filter to a given list of rooms.
 
         Args:
-            definition(dict): The definition to check against.
-            room_id(str): The id of the room this event is in or None.
-            sender(str): The sender of the event
-            event_type(str): The type of the event.
+            room_ids (list): A list of room_ids.
+
         Returns:
-            True if the event passes through the filter.
+            list: A list of room_ids that match the filter
         """
-        # Algorithm notes:
-        # For each key in the definition, check the event meets the criteria:
-        #   * For types: Literal match or prefix match (if ends with wildcard)
-        #   * For senders/rooms: Literal match only
-        #   * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
-        #     and 'not_types' then it is treated as only being in 'not_types')
-
-        # room checks
-        if room_id is not None:
-            allow_rooms = definition.get("rooms", None)
-            reject_rooms = definition.get("not_rooms", None)
-            if reject_rooms and room_id in reject_rooms:
-                return False
-            if allow_rooms and room_id not in allow_rooms:
-                return False
+        room_ids = set(room_ids)
 
-        # sender checks
-        if sender is not None:
-            allow_senders = definition.get("senders", None)
-            reject_senders = definition.get("not_senders", None)
-            if reject_senders and sender in reject_senders:
-                return False
-            if allow_senders and sender not in allow_senders:
-                return False
+        disallowed_rooms = set(self.filter_json.get("not_rooms", []))
+        room_ids -= disallowed_rooms
 
-        # type checks
-        if "not_types" in definition:
-            for def_type in definition["not_types"]:
-                if self._event_matches_type(event_type, def_type):
-                    return False
-        if "types" in definition:
-            included = False
-            for def_type in definition["types"]:
-                if self._event_matches_type(event_type, def_type):
-                    included = True
-                    break
-            if not included:
-                return False
+        allowed_rooms = self.filter_json.get("rooms", None)
+        if allowed_rooms is not None:
+            room_ids &= set(allowed_rooms)
+
+        return room_ids
+
+    def filter(self, events):
+        return filter(self.check, events)
+
+    def limit(self):
+        return self.filter_json.get("limit", 10)
 
-        return True
 
-    def _event_matches_type(self, event_type, def_type):
-        if def_type.endswith("*"):
-            type_prefix = def_type[:-1]
-            return event_type.startswith(type_prefix)
-        else:
-            return event_type == def_type
+def _matches_wildcard(actual_value, filter_value):
+    if filter_value.endswith("*"):
+        type_prefix = filter_value[:-1]
+        return actual_value.startswith(type_prefix)
+    else:
+        return actual_value == filter_value
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 22808b9c07..f53e5d35ac 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from ._base import BaseHandler
 
 from synapse.api.constants import Membership
+from synapse.api.filtering import Filter
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
 
@@ -49,9 +50,12 @@ class SearchHandler(BaseHandler):
             keys = content["search_categories"]["room_events"].get("keys", [
                 "content.body", "content.name", "content.topic",
             ])
+            filter_dict = content["search_categories"]["room_events"].get("filter", {})
         except KeyError:
             raise SynapseError(400, "Invalid search query")
 
+        filtr = Filter(filter_dict)
+
         # TODO: Search through left rooms too
         rooms = yield self.store.get_rooms_for_user_where_membership_is(
             user.to_string(),
@@ -60,15 +64,16 @@ class SearchHandler(BaseHandler):
         )
         room_ids = set(r.room_id for r in rooms)
 
-        # TODO: Apply room filter to rooms list
+        room_ids = filtr.filter_rooms(room_ids)
 
         rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys)
 
+        filtered_events = filtr.filter(event_map.values())
+
         allowed_events = yield self._filter_events_for_client(
-            user.to_string(), event_map.values()
+            user.to_string(), filtered_events
         )
 
-        # TODO: Filter allowed_events
         # TODO: Add a limit
 
         time_now = self.clock.time_msec()
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 73473a7e6b..6c4f2b7cd4 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -23,7 +23,7 @@ from synapse.types import StreamToken
 from synapse.events.utils import (
     serialize_event, format_event_for_client_v2_without_event_id,
 )
-from synapse.api.filtering import Filter
+from synapse.api.filtering import FilterCollection
 from ._base import client_v2_pattern
 
 import copy
@@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet):
                 user.localpart, filter_id
             )
         except:
-            filter = Filter({})
+            filter = FilterCollection({})
 
         sync_config = SyncConfig(
             user=user,
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1ddf55be4d..1a74d6e360 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 24
+SCHEMA_VERSION = 25
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/25/fts.py
index 0c752d8426..ed3cc06557 100644
--- a/synapse/storage/schema/delta/24/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
 
 
 POSTGRES_SQL = """
-CREATE TABLE event_search (
+CREATE TABLE IF NOT EXISTS event_search (
     event_id TEXT,
     room_id TEXT,
     key TEXT,
@@ -53,7 +53,8 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id);
 
 
 SQLITE_TABLE = (
-    "CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)"
+    "CREATE VIRTUAL TABLE IF NOT EXISTS event_search"
+    " USING fts3 ( event_id, room_id, key, value)"
 )
 
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6942cdac51..9f9af2d783 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -23,10 +23,17 @@ from tests.utils import (
 )
 
 from synapse.types import UserID
-from synapse.api.filtering import Filter
+from synapse.api.filtering import FilterCollection, Filter
 
 user_localpart = "test_user"
-MockEvent = namedtuple("MockEvent", "sender type room_id")
+# MockEvent = namedtuple("MockEvent", "sender type room_id")
+
+
+def MockEvent(**kwargs):
+    ev = NonCallableMock(spec_set=kwargs.keys())
+    ev.configure_mock(**kwargs)
+    return ev
+
 
 class FilteringTestCase(unittest.TestCase):
 
@@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase):
         )
 
         self.filtering = hs.get_filtering()
-        self.filter = Filter({})
 
         self.datastore = hs.get_datastore()
 
@@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase):
             type="m.room.message",
             room_id="!foo:bar"
         )
+
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_types_works_with_wildcards(self):
@@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_types_works_with_unknowns(self):
@@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_works_with_literals(self):
@@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_works_with_wildcards(self):
@@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_works_with_unknowns(self):
@@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_types_takes_priority_over_types(self):
@@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_senders_works_with_literals(self):
@@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_senders_works_with_unknowns(self):
@@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_senders_works_with_literals(self):
@@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_senders_works_with_unknowns(self):
@@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_senders_takes_priority_over_senders(self):
@@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_rooms_works_with_literals(self):
@@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_rooms_works_with_unknowns(self):
@@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_rooms_works_with_literals(self):
@@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_rooms_works_with_unknowns(self):
@@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_not_rooms_takes_priority_over_rooms(self):
@@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event(self):
@@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertTrue(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event_bad_sender(self):
@@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event_bad_room(self):
@@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!piggyshouse:muppets"  # nope
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     def test_definition_combined_event_bad_type(self):
@@ -341,7 +348,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filter._passes_definition(definition, event)
+            Filter(definition).check(event)
         )
 
     @defer.inlineCallbacks
@@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(
             sender="@foo:bar",
             type="m.profile",
-            room_id="!foo:bar"
         )
         events = [event]
 
@@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase):
         event = MockEvent(
             sender="@foo:bar",
             type="custom.avatar.3d.crazy",
-            room_id="!foo:bar"
         )
         events = [event]