summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-10-22 17:14:12 +0100
committerErik Johnston <erik@matrix.org>2015-10-22 17:14:12 +0100
commit53c679b59b664a248d4fce138806efe2e284c7e4 (patch)
treeabdddc0b5aadc1bd11b74c89cbf7af761d694885
parentMerge pull request #319 from matrix-org/erikj/filter_refactor (diff)
parentActually filter results (diff)
downloadsynapse-53c679b59b664a248d4fce138806efe2e284c7e4.tar.xz
Merge pull request #324 from matrix-org/erikj/search
Add filters to search.
-rw-r--r--synapse/api/filtering.py20
-rw-r--r--synapse/handlers/search.py15
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/delta/25/fts.py (renamed from synapse/storage/schema/delta/24/fts.py)5
-rw-r--r--synapse/storage/search.py35
5 files changed, 61 insertions, 16 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 60b6648e0d..ab14b47281 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -202,6 +202,26 @@ class Filter(object):
 
         return True
 
+    def filter_rooms(self, room_ids):
+        """Apply the 'rooms' filter to a given list of rooms.
+
+        Args:
+            room_ids (list): A list of room_ids.
+
+        Returns:
+            list: A list of room_ids that match the filter
+        """
+        room_ids = set(room_ids)
+
+        disallowed_rooms = set(self.filter_json.get("not_rooms", []))
+        room_ids -= disallowed_rooms
+
+        allowed_rooms = self.filter_json.get("rooms", None)
+        if allowed_rooms is not None:
+            room_ids &= set(allowed_rooms)
+
+        return room_ids
+
     def filter(self, events):
         return filter(self.check, events)
 
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 22808b9c07..bbe82b1425 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from ._base import BaseHandler
 
 from synapse.api.constants import Membership
+from synapse.api.filtering import Filter
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
 
@@ -49,9 +50,12 @@ class SearchHandler(BaseHandler):
             keys = content["search_categories"]["room_events"].get("keys", [
                 "content.body", "content.name", "content.topic",
             ])
+            filter_dict = content["search_categories"]["room_events"].get("filter", {})
         except KeyError:
             raise SynapseError(400, "Invalid search query")
 
+        search_filter = Filter(filter_dict)
+
         # TODO: Search through left rooms too
         rooms = yield self.store.get_rooms_for_user_where_membership_is(
             user.to_string(),
@@ -60,15 +64,18 @@ class SearchHandler(BaseHandler):
         )
         room_ids = set(r.room_id for r in rooms)
 
-        # TODO: Apply room filter to rooms list
+        room_ids = search_filter.filter_rooms(room_ids)
+
+        rank_map, event_map, _ = yield self.store.search_msgs(
+            room_ids, search_term, keys
+        )
 
-        rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys)
+        filtered_events = search_filter.filter(event_map.values())
 
         allowed_events = yield self._filter_events_for_client(
-            user.to_string(), event_map.values()
+            user.to_string(), filtered_events
         )
 
-        # TODO: Filter allowed_events
         # TODO: Add a limit
 
         time_now = self.clock.time_msec()
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1ddf55be4d..1a74d6e360 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 24
+SCHEMA_VERSION = 25
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/schema/delta/24/fts.py b/synapse/storage/schema/delta/25/fts.py
index 0c752d8426..ed3cc06557 100644
--- a/synapse/storage/schema/delta/24/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
 
 
 POSTGRES_SQL = """
-CREATE TABLE event_search (
+CREATE TABLE IF NOT EXISTS event_search (
     event_id TEXT,
     room_id TEXT,
     key TEXT,
@@ -53,7 +53,8 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id);
 
 
 SQLITE_TABLE = (
-    "CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)"
+    "CREATE VIRTUAL TABLE IF NOT EXISTS event_search"
+    " USING fts3 ( event_id, room_id, key, value)"
 )
 
 
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index a3c69c5ab3..9608b5d6a7 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -18,6 +18,17 @@ from twisted.internet import defer
 from _base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
+from collections import namedtuple
+
+"""The result of a search.
+
+Fields:
+    rank_map (dict): Mapping event_id -> rank
+    event_map (dict): Mapping event_id -> event
+    pagination_token (str): Pagination token
+"""
+SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
+
 
 class SearchStore(SQLBaseStore):
     @defer.inlineCallbacks
@@ -31,15 +42,18 @@ class SearchStore(SQLBaseStore):
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            2-tuple of (dict event_id -> rank, dict event_id -> event)
+            SearchResult
         """
         clauses = []
         args = []
 
-        clauses.append(
-            "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
-        )
-        args.extend(room_ids)
+        # Make sure we don't explode because the person is in too many rooms.
+        # We filter the results below regardless.
+        if len(room_ids) < 500:
+            clauses.append(
+                "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
+            )
+            args.extend(room_ids)
 
         local_clauses = []
         for key in keys:
@@ -52,13 +66,13 @@ class SearchStore(SQLBaseStore):
 
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
-                "SELECT ts_rank_cd(vector, query) AS rank, event_id"
+                "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
                 " FROM plainto_tsquery('english', ?) as query, event_search"
                 " WHERE vector @@ query"
             )
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
-                "SELECT 0 as rank, event_id FROM event_search"
+                "SELECT 0 as rank, room_id, event_id FROM event_search"
                 " WHERE value MATCH ?"
             )
         else:
@@ -76,6 +90,8 @@ class SearchStore(SQLBaseStore):
             "search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
         )
 
+        results = filter(lambda row: row["room_id"] in room_ids, results)
+
         events = yield self._get_events([r["event_id"] for r in results])
 
         event_map = {
@@ -83,11 +99,12 @@ class SearchStore(SQLBaseStore):
             for ev in events
         }
 
-        defer.returnValue((
+        defer.returnValue(SearchResult(
             {
                 r["event_id"]: r["rank"]
                 for r in results
                 if r["event_id"] in event_map
             },
-            event_map
+            event_map,
+            None
         ))