summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xscripts/synapse_port_db123
-rw-r--r--synapse/config/jwt.py2
-rw-r--r--synapse/rest/client/v1/login.py9
-rw-r--r--synapse/storage/background_updates.py3
-rw-r--r--synapse/storage/room.py16
-rw-r--r--synapse/storage/schema/delta/31/search_update.py65
-rw-r--r--synapse/storage/search.py92
7 files changed, 258 insertions, 52 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 253a6ef6c7..efd04da2d6 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -214,6 +214,10 @@ class Porter(object):
 
         self.progress.add_table(table, postgres_size, table_size)
 
+        if table == "event_search":
+            yield self.handle_search_table(postgres_size, table_size, next_chunk)
+            return
+
         select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
             % (table,)
@@ -232,60 +236,95 @@ class Porter(object):
             if rows:
                 next_chunk = rows[-1][0] + 1
 
-                if table == "event_search":
-                    # We have to treat event_search differently since it has a
-                    # different structure in the two different databases.
-                    def insert(txn):
-                        sql = (
-                            "INSERT INTO event_search (event_id, room_id, key, sender, vector)"
-                            " VALUES (?,?,?,?,to_tsvector('english', ?))"
-                        )
+                self._convert_rows(table, headers, rows)
 
-                        rows_dict = [
-                            dict(zip(headers, row))
-                            for row in rows
-                        ]
-
-                        txn.executemany(sql, [
-                            (
-                                row["event_id"],
-                                row["room_id"],
-                                row["key"],
-                                row["sender"],
-                                row["value"],
-                            )
-                            for row in rows_dict
-                        ])
-
-                        self.postgres_store._simple_update_one_txn(
-                            txn,
-                            table="port_from_sqlite3",
-                            keyvalues={"table_name": table},
-                            updatevalues={"rowid": next_chunk},
-                        )
-                else:
-                    self._convert_rows(table, headers, rows)
+                def insert(txn):
+                    self.postgres_store.insert_many_txn(
+                        txn, table, headers[1:], rows
+                    )
 
-                    def insert(txn):
-                        self.postgres_store.insert_many_txn(
-                            txn, table, headers[1:], rows
-                        )
+                    self.postgres_store._simple_update_one_txn(
+                        txn,
+                        table="port_from_sqlite3",
+                        keyvalues={"table_name": table},
+                        updatevalues={"rowid": next_chunk},
+                    )
+
+                yield self.postgres_store.execute(insert)
+
+                postgres_size += len(rows)
+
+                self.progress.update(table, postgres_size)
+            else:
+                return
+
+    @defer.inlineCallbacks
+    def handle_search_table(self, postgres_size, table_size, next_chunk):
+        select = (
+            "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
+            " FROM event_search as es"
+            " INNER JOIN events AS e USING (event_id, room_id)"
+            " WHERE es.rowid >= ?"
+            " ORDER BY es.rowid LIMIT ?"
+        )
 
-                        self.postgres_store._simple_update_one_txn(
-                            txn,
-                            table="port_from_sqlite3",
-                            keyvalues={"table_name": table},
-                            updatevalues={"rowid": next_chunk},
+        while True:
+            def r(txn):
+                txn.execute(select, (next_chunk, self.batch_size,))
+                rows = txn.fetchall()
+                headers = [column[0] for column in txn.description]
+
+                return headers, rows
+
+            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+
+            if rows:
+                next_chunk = rows[-1][0] + 1
+
+                # We have to treat event_search differently since it has a
+                # different structure in the two different databases.
+                def insert(txn):
+                    sql = (
+                        "INSERT INTO event_search (event_id, room_id, key,"
+                        " sender, vector, origin_server_ts, stream_ordering)"
+                        " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
+                    )
+
+                    rows_dict = [
+                        dict(zip(headers, row))
+                        for row in rows
+                    ]
+
+                    txn.executemany(sql, [
+                        (
+                            row["event_id"],
+                            row["room_id"],
+                            row["key"],
+                            row["sender"],
+                            row["value"],
+                            row["origin_server_ts"],
+                            row["stream_ordering"],
                         )
+                        for row in rows_dict
+                    ])
+
+                    self.postgres_store._simple_update_one_txn(
+                        txn,
+                        table="port_from_sqlite3",
+                        keyvalues={"table_name": "event_search"},
+                        updatevalues={"rowid": next_chunk},
+                    )
 
                 yield self.postgres_store.execute(insert)
 
                 postgres_size += len(rows)
 
-                self.progress.update(table, postgres_size)
+                self.progress.update("event_search", postgres_size)
+
             else:
                 return
 
+
     def setup_db(self, db_config, database_engine):
         db_conn = database_engine.module.connect(
             **{
diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py
index 4cb092bbec..5c8199612b 100644
--- a/synapse/config/jwt.py
+++ b/synapse/config/jwt.py
@@ -30,6 +30,8 @@ class JWTConfig(Config):
 
     def default_config(self, **kwargs):
         return """\
+        # The JWT needs to contain a globally unique "sub" (subject) claim.
+        #
         # jwt_config:
         #    enabled: true
         #    secret: "a secret"
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d14ce3efa2..166a78026a 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -224,16 +224,19 @@ class LoginRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def do_jwt_login(self, login_submission):
-        token = login_submission['token']
+        token = login_submission.get("token", None)
         if token is None:
-            raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+            raise LoginError(401, "Token field for JWT is missing",
+                             errcode=Codes.UNAUTHORIZED)
 
         try:
             payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+        except jwt.ExpiredSignatureError:
+            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
         except InvalidTokenError:
             raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
-        user = payload['user']
+        user = payload.get("sub", None)
         if user is None:
             raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 49904046cf..66a995157d 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -173,11 +173,12 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         logger.info(
             "Updating %r. Updated %r items in %rms."
-            " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
+            " (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
             update_name, items_updated, duration_ms,
             performance.total_items_per_ms(),
             performance.average_items_per_ms(),
             performance.total_item_count,
+            batch_size,
         )
 
         performance.update(items_updated, duration_ms)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 9be977f387..70aa64fb31 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -169,20 +169,28 @@ class RoomStore(SQLBaseStore):
     def _store_event_search_txn(self, txn, event, key, value):
         if isinstance(self.database_engine, PostgresEngine):
             sql = (
-                "INSERT INTO event_search (event_id, room_id, key, vector)"
-                " VALUES (?,?,?,to_tsvector('english', ?))"
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+            txn.execute(
+                sql,
+                (
+                    event.event_id, event.room_id, key, value,
+                    event.internal_metadata.stream_ordering,
+                    event.origin_server_ts,
+                )
             )
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
                 "INSERT INTO event_search (event_id, room_id, key, value)"
                 " VALUES (?,?,?,?)"
             )
+            txn.execute(sql, (event.event_id, event.room_id, key, value,))
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        txn.execute(sql, (event.event_id, event.room_id, key, value,))
-
     @cachedInlineCallbacks()
     def get_room_name_and_aliases(self, room_id):
         def f(txn):
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
new file mode 100644
index 0000000000..470ae0c005
--- /dev/null
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -0,0 +1,65 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+
+ALTER_TABLE = """
+ALTER TABLE event_search ADD COLUMN origin_server_ts BIGINT;
+ALTER TABLE event_search ADD COLUMN stream_ordering BIGINT;
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if not isinstance(database_engine, PostgresEngine):
+        return
+
+    for statement in get_statements(ALTER_TABLE.splitlines()):
+        cur.execute(statement)
+
+    cur.execute("SELECT MIN(stream_ordering) FROM events")
+    rows = cur.fetchall()
+    min_stream_id = rows[0][0]
+
+    cur.execute("SELECT MAX(stream_ordering) FROM events")
+    rows = cur.fetchall()
+    max_stream_id = rows[0][0]
+
+    if min_stream_id is not None and max_stream_id is not None:
+        progress = {
+            "target_min_stream_id_inclusive": min_stream_id,
+            "max_stream_id_exclusive": max_stream_id + 1,
+            "rows_inserted": 0,
+            "have_added_indexes": False,
+        }
+        progress_json = ujson.dumps(progress)
+
+        sql = (
+            "INSERT into background_updates (update_name, progress_json)"
+            " VALUES (?, ?)"
+        )
+
+        sql = database_engine.convert_param_style(sql)
+
+        cur.execute(sql, ("event_search_order", progress_json))
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 59ac7f424c..0224299625 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -29,12 +29,17 @@ logger = logging.getLogger(__name__)
 class SearchStore(BackgroundUpdateStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
+    EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
 
     def __init__(self, hs):
         super(SearchStore, self).__init__(hs)
         self.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )
+        self.register_background_update_handler(
+            self.EVENT_SEARCH_ORDER_UPDATE_NAME,
+            self._background_reindex_search_order
+        )
 
     @defer.inlineCallbacks
     def _background_reindex_search(self, progress, batch_size):
@@ -132,6 +137,82 @@ class SearchStore(BackgroundUpdateStore):
         defer.returnValue(result)
 
     @defer.inlineCallbacks
+    def _background_reindex_search_order(self, progress, batch_size):
+        target_min_stream_id = progress["target_min_stream_id_inclusive"]
+        max_stream_id = progress["max_stream_id_exclusive"]
+        rows_inserted = progress.get("rows_inserted", 0)
+        have_added_index = progress['have_added_indexes']
+
+        if not have_added_index:
+            def create_index(conn):
+                conn.rollback()
+                conn.set_session(autocommit=True)
+                c = conn.cursor()
+
+                # We create with NULLS FIRST so that when we search *backwards*
+                # we get the ones with non null origin_server_ts *first*
+                c.execute(
+                    "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
+                    "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+                )
+                c.execute(
+                    "CREATE INDEX CONCURRENTLY event_search_order ON event_search("
+                    "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+                )
+                conn.set_session(autocommit=False)
+
+            yield self.runWithConnection(create_index)
+
+            pg = dict(progress)
+            pg["have_added_indexes"] = True
+
+            yield self.runInteraction(
+                self.EVENT_SEARCH_ORDER_UPDATE_NAME,
+                self._background_update_progress_txn,
+                self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg,
+            )
+
+        def reindex_search_txn(txn):
+            sql = (
+                "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
+                " origin_server_ts = e.origin_server_ts"
+                " FROM events AS e"
+                " WHERE e.event_id = es.event_id"
+                " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
+                " RETURNING es.stream_ordering"
+            )
+
+            min_stream_id = max_stream_id - batch_size
+            txn.execute(sql, (min_stream_id, max_stream_id))
+            rows = txn.fetchall()
+
+            if min_stream_id < target_min_stream_id:
+                # We've recached the end.
+                return len(rows), False
+
+            progress = {
+                "target_min_stream_id_inclusive": target_min_stream_id,
+                "max_stream_id_exclusive": min_stream_id,
+                "rows_inserted": rows_inserted + len(rows),
+                "have_added_indexes": True,
+            }
+
+            self._background_update_progress_txn(
+                txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
+            )
+
+            return len(rows), True
+
+        num_rows, finished = yield self.runInteraction(
+            self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
+        )
+
+        if not finished:
+            yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
+
+        defer.returnValue(num_rows)
+
+    @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
 
@@ -310,7 +391,6 @@ class SearchStore(BackgroundUpdateStore):
                 "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
                 " origin_server_ts, stream_ordering, room_id, event_id"
                 " FROM event_search"
-                " NATURAL JOIN events"
                 " WHERE vector @@ to_tsquery('english', ?) AND "
             )
             args = [search_query, search_query] + args
@@ -355,7 +435,15 @@ class SearchStore(BackgroundUpdateStore):
 
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
-        sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
+        if isinstance(self.database_engine, PostgresEngine):
+            sql += (
+                " ORDER BY origin_server_ts DESC NULLS LAST,"
+                " stream_ordering DESC NULLS LAST LIMIT ?"
+            )
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
+        else:
+            raise Exception("Unrecognized database engine")
 
         args.append(limit)