summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorSteven Hammerton <steven.hammerton@openmarket.com>2015-11-05 20:59:45 +0000
committerSteven Hammerton <steven.hammerton@openmarket.com>2015-11-05 20:59:45 +0000
commitfece2f5c771ca4baab8a0541ede8af2e33ff1b41 (patch)
tree22973a8a9a9b806cb09d734e19cb5150686e9869 /synapse/storage
parentAllow hs to do CAS login completely and issue the client with a login token t... (diff)
parentMerge pull request #350 from matrix-org/erikj/search (diff)
downloadsynapse-fece2f5c771ca4baab8a0541ede8af2e33ff1b41.tar.xz
Merge branch 'develop' into sh-cas-auth-via-homeserver
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events.py2
-rw-r--r--synapse/storage/room.py13
-rw-r--r--synapse/storage/schema/delta/25/history_visibility.sql26
-rw-r--r--synapse/storage/search.py120
-rw-r--r--synapse/storage/stream.py46
5 files changed, 182 insertions, 25 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index e6c1abfc27..59c9987202 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore):
                 self._store_room_message_txn(txn, event)
             elif event.type == EventTypes.Redaction:
                 self._store_redaction(txn, event)
+            elif event.type == EventTypes.RoomHistoryVisibility:
+                self._store_history_visibility_txn(txn, event)
 
         self._store_room_members_txn(
             txn,
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 13441fcdce..1c79626736 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore):
                 txn, event, "content.body", event.content["body"]
             )
 
+    def _store_history_visibility_txn(self, txn, event):
+        if hasattr(event, "content") and "history_visibility" in event.content:
+            sql = (
+                "INSERT INTO history_visibility"
+                " (event_id, room_id, history_visibility)"
+                " VALUES (?, ?, ?)"
+            )
+            txn.execute(sql, (
+                event.event_id,
+                event.room_id,
+                event.content["history_visibility"]
+            ))
+
     def _store_event_search_txn(self, txn, event, key, value):
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/schema/delta/25/history_visibility.sql
new file mode 100644
index 0000000000..9f387ed69f
--- /dev/null
+++ b/synapse/storage/schema/delta/25/history_visibility.sql
@@ -0,0 +1,26 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This is a manual index of history_visibility content of state events,
+ * so that we can join on them in SELECT statements.
+ */
+CREATE TABLE IF NOT EXISTS history_visibility(
+    id INTEGER PRIMARY KEY,
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    history_visibility TEXT NOT NULL,
+    UNIQUE (event_id)
+);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index cdf003502f..3cea2011fa 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -16,18 +16,13 @@
 from twisted.internet import defer
 
 from _base import SQLBaseStore
+from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-from collections import namedtuple
+import logging
 
-"""The result of a search.
 
-Fields:
-    rank_map (dict): Mapping event_id -> rank
-    event_map (dict): Mapping event_id -> event
-    pagination_token (str): Pagination token
-"""
-SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
+logger = logging.getLogger(__name__)
 
 
 class SearchStore(SQLBaseStore):
@@ -42,7 +37,7 @@ class SearchStore(SQLBaseStore):
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            SearchResult
+            list of dicts
         """
         clauses = []
         args = []
@@ -100,12 +95,103 @@ class SearchStore(SQLBaseStore):
             for ev in events
         }
 
-        defer.returnValue(SearchResult(
+        defer.returnValue([
             {
-                r["event_id"]: r["rank"]
-                for r in results
-                if r["event_id"] in event_map
-            },
-            event_map,
-            None
-        ))
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])
+
+    @defer.inlineCallbacks
+    def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
+        """Performs a full text search over events with given keys.
+
+        Args:
+            room_id (str): The room_id to search in
+            search_term (str): Search term to search for
+            keys (list): List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            pagination_token (str): A pagination token previously returned
+
+        Returns:
+            list of dicts
+        """
+        clauses = []
+        args = [search_term, room_id]
+
+        local_clauses = []
+        for key in keys:
+            local_clauses.append("key = ?")
+            args.append(key)
+
+        clauses.append(
+            "(%s)" % (" OR ".join(local_clauses),)
+        )
+
+        if pagination_token:
+            try:
+                topo, stream = pagination_token.split(",")
+                topo = int(topo)
+                stream = int(stream)
+            except:
+                raise SynapseError(400, "Invalid pagination token")
+
+            clauses.append(
+                "(topological_ordering < ?"
+                " OR (topological_ordering = ? AND stream_ordering < ?))"
+            )
+            args.extend([topo, topo, stream])
+
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "SELECT ts_rank_cd(vector, query) as rank,"
+                " topological_ordering, stream_ordering, room_id, event_id"
+                " FROM plainto_tsquery('english', ?) as query, event_search"
+                " NATURAL JOIN events"
+                " WHERE vector @@ query AND room_id = ?"
+            )
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
+                " topological_ordering, stream_ordering"
+                " FROM event_search"
+                " NATURAL JOIN events"
+                " WHERE value MATCH ? AND room_id = ?"
+            )
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        for clause in clauses:
+            sql += " AND " + clause
+
+        # We add an arbitrary limit here to ensure we don't try to pull the
+        # entire table from the database.
+        sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
+
+        args.append(limit)
+
+        results = yield self._execute(
+            "search_rooms", self.cursor_to_dict, sql, *args
+        )
+
+        events = yield self._get_events([r["event_id"] for r in results])
+
+        event_map = {
+            ev.event_id: ev
+            for ev in events
+        }
+
+        defer.returnValue([
+            {
+                "event": event_map[r["event_id"]],
+                "rank": r["rank"],
+                "pagination_token": "%s,%s" % (
+                    r["topological_ordering"], r["stream_ordering"]
+                ),
+            }
+            for r in results
+            if r["event_id"] in event_map
+        ])
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index c728013f4c..be8ba76aae 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -158,13 +158,40 @@ class StreamStore(SQLBaseStore):
         defer.returnValue(results)
 
     @log_function
-    def get_room_events_stream(self, user_id, from_key, to_key, limit=0):
-        current_room_membership_sql = (
-            "SELECT m.room_id FROM room_memberships as m "
-            " INNER JOIN current_state_events as c"
-            " ON m.event_id = c.event_id AND c.state_key = m.user_id"
-            " WHERE m.user_id = ? AND m.membership = 'join'"
-        )
+    def get_room_events_stream(
+        self,
+        user_id,
+        from_key,
+        to_key,
+        limit=0,
+        is_guest=False,
+        room_ids=None
+    ):
+        room_ids = room_ids or []
+        room_ids = [r for r in room_ids]
+        if is_guest:
+            current_room_membership_sql = (
+                "SELECT c.room_id FROM history_visibility AS h"
+                " INNER JOIN current_state_events AS c"
+                " ON h.event_id = c.event_id"
+                " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+                    ",".join(map(lambda _: "?", room_ids))
+                )
+            )
+            current_room_membership_args = room_ids
+        else:
+            current_room_membership_sql = (
+                "SELECT m.room_id FROM room_memberships as m "
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id AND c.state_key = m.user_id"
+                " WHERE m.user_id = ? AND m.membership = 'join'"
+            )
+            current_room_membership_args = [user_id]
+            if room_ids:
+                current_room_membership_sql += " AND m.room_id in (%s)" % (
+                    ",".join(map(lambda _: "?", room_ids))
+                )
+                current_room_membership_args = [user_id] + room_ids
 
         # We also want to get any membership events about that user, e.g.
         # invites or leave notifications.
@@ -173,6 +200,7 @@ class StreamStore(SQLBaseStore):
             "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
             "WHERE m.user_id = ? "
         )
+        membership_args = [user_id]
 
         if limit:
             limit = max(limit, MAX_STREAM_SIZE)
@@ -199,7 +227,9 @@ class StreamStore(SQLBaseStore):
         }
 
         def f(txn):
-            txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))
+            args = ([False] + current_room_membership_args + membership_args +
+                    [from_id.stream, to_id.stream])
+            txn.execute(sql, args)
 
             rows = self.cursor_to_dict(txn)