summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/snapshot.py16
-rw-r--r--synapse/state.py33
-rw-r--r--synapse/storage/events.py49
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/delta/35/state.sql21
-rw-r--r--synapse/storage/state.py221
6 files changed, 257 insertions, 85 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e895b1c450..ec32008d5a 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,9 +15,25 @@
 
 
 class EventContext(object):
+    __slots__ = [
+        "current_state_ids",
+        "prev_state_ids",
+        "state_group",
+        "rejected",
+        "push_actions",
+        "prev_group",
+        "delta_ids",
+        "prev_state_events",
+    ]
+
     def __init__(self):
         self.current_state_ids = None
         self.prev_state_ids = None
         self.state_group = None
         self.rejected = False
         self.push_actions = []
+
+        self.prev_group = None
+        self.delta_ids = None
+
+        self.prev_state_events = None
diff --git a/synapse/state.py b/synapse/state.py
index cd792afed1..4520fa0415 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -55,12 +55,15 @@ def _gen_state_id():
 
 
 class _StateCacheEntry(object):
-    __slots__ = ["state", "state_group", "state_id"]
+    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
-    def __init__(self, state, state_group):
+    def __init__(self, state, state_group, prev_group=None, delta_ids=None):
         self.state = state
         self.state_group = state_group
 
+        self.prev_group = prev_group
+        self.delta_ids = delta_ids
+
         # The `state_id` is a unique ID we generate that can be used as ID for
         # this collection of state. Usually this would be the same as the
         # state group, but on worker instances we can't generate a new state
@@ -245,11 +248,25 @@ class StateHandler(object):
             if key in context.prev_state_ids:
                 replaces = context.prev_state_ids[key]
                 event.unsigned["replaces_state"] = replaces
+
             context.current_state_ids = dict(context.prev_state_ids)
             context.current_state_ids[key] = event.event_id
+
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+            if context.delta_ids is not None:
+                # Copy before adding this event. `entry.delta_ids` is owned by
+                # the _StateCacheEntry (which may live on in the state cache),
+                # so writing to it in place would leak this event into every
+                # other context built from the same entry.
+                context.delta_ids = dict(context.delta_ids)
+                context.delta_ids[key] = event.event_id
         else:
             context.current_state_ids = context.prev_state_ids
 
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
         context.prev_state_events = []
         defer.returnValue(context)
 
@@ -283,6 +295,8 @@ class StateHandler(object):
             defer.returnValue(_StateCacheEntry(
                 state=state_list,
                 state_group=name,
+                prev_group=name,
+                delta_ids={},
             ))
 
         with (yield self.resolve_linearizer.queue(group_names)):
@@ -340,9 +354,28 @@ class StateHandler(object):
                 if hasattr(self.store, "get_next_state_group"):
                     state_group = self.store.get_next_state_group()
 
+            # Look for the input group that the resolved state differs from
+            # least, so it can be recorded as `prev_group` plus a small delta.
+            prev_group = None
+            delta_ids = None
+            for old_group, old_ids in state_groups_ids.items():
+                if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
+                    n_delta_ids = {
+                        k: v
+                        for k, v in new_state.items()
+                        if old_ids.get(k) != v
+                    }
+                    # Test for None rather than truthiness: an empty delta is
+                    # the best possible match and must not be overwritten.
+                    if delta_ids is None or len(n_delta_ids) < len(delta_ids):
+                        prev_group = old_group
+                        delta_ids = n_delta_ids
+
             cache = _StateCacheEntry(
                 state=new_state,
                 state_group=state_group,
+                prev_group=prev_group,
+                delta_ids=delta_ids,
             )
 
             if self._state_cache is not None:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 1a7d4c5199..7e9b351513 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -497,7 +497,13 @@ class EventsStore(SQLBaseStore):
 
                 # insert into the state_group, state_groups_state and
                 # event_to_state_groups tables.
-                self._store_mult_state_groups_txn(txn, ((event, context),))
+                try:
+                    self._store_mult_state_groups_txn(txn, ((event, context),))
+                except Exception:
+                    logger.exception(
+                        "Failed to store state groups for %s", event.event_id
+                    )
+                    raise
 
                 metadata_json = encode_json(
                     event.internal_metadata.get_dict()
@@ -1543,6 +1547,9 @@ class EventsStore(SQLBaseStore):
         )
         event_rows = txn.fetchall()
 
+        for event_id, state_key in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
         txn.execute(
@@ -1571,26 +1578,10 @@ class EventsStore(SQLBaseStore):
 
         # Get all state groups that are only referenced by events that are
         # to be deleted.
-        txn.execute(
-            "SELECT state_group FROM event_to_state_groups"
-            " INNER JOIN events USING (event_id)"
-            " WHERE state_group IN ("
-            "   SELECT DISTINCT state_group FROM events"
-            "   INNER JOIN event_to_state_groups USING (event_id)"
-            "   WHERE room_id = ? AND topological_ordering < ?"
-            " )"
-            " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
-            (room_id, topological_ordering, topological_ordering)
-        )
-        state_rows = txn.fetchall()
-        txn.executemany(
-            "DELETE FROM state_groups_state WHERE state_group = ?",
-            state_rows
-        )
-        txn.executemany(
-            "DELETE FROM state_groups WHERE id = ?",
-            state_rows
-        )
+        # NOTE(review): those groups are no longer deleted. Presumably this is
+        # because a group can now be the `prev_state_group` of another group
+        # in `state_group_edges`, so removing it would corrupt its dependants
+        # -- confirm, then either restore a safe deletion or drop this comment.
         # Delete all non-state
         txn.executemany(
             "DELETE FROM event_to_state_groups WHERE event_id = ?",
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b94ce7bea1..b1fbc4ffa5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 34
+SCHEMA_VERSION = 35
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/schema/delta/35/state.sql
new file mode 100644
index 0000000000..c4c244c169
--- /dev/null
+++ b/synapse/storage/schema/delta/35/state.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_group_edges(
+    state_group BIGINT NOT NULL,
+    prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ec551b0b4f..7f45c0cd99 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -16,6 +16,7 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches import intern_string
+from synapse.storage.engines import PostgresEngine
 
 from twisted.internet import defer
 
@@ -24,6 +25,9 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+MAX_STATE_DELTA_HOPS = 100
+
+
 class StateStore(SQLBaseStore):
     """ Keeps track of the state at a given event.
 
@@ -103,7 +107,6 @@ class StateStore(SQLBaseStore):
             state_groups[event.event_id] = context.state_group
 
             if self._have_persisted_state_group_txn(txn, context.state_group):
-                logger.info("Already persisted state_group: %r", context.state_group)
                 continue
 
             state_event_ids = dict(context.current_state_ids)
@@ -118,20 +121,39 @@ class StateStore(SQLBaseStore):
                 },
             )
 
+            # Store the full state by default. When the group has a
+            # predecessor and its delta chain is still short enough, record
+            # the edge and store only the delta against that predecessor.
+            ids_to_store = state_event_ids
+            if context.prev_group:
+                potential_hops = self._count_state_group_hops_txn(
+                    txn, context.prev_group
+                )
+                if potential_hops < MAX_STATE_DELTA_HOPS:
+                    self._simple_insert_txn(
+                        txn,
+                        table="state_group_edges",
+                        values={
+                            "state_group": context.state_group,
+                            "prev_state_group": context.prev_group,
+                        },
+                    )
+                    ids_to_store = context.delta_ids
+
             self._simple_insert_many_txn(
                 txn,
                 table="state_groups_state",
                 values=[
                     {
                         "state_group": context.state_group,
                         "room_id": event.room_id,
                         "type": key[0],
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in state_event_ids.items()
+                    for key, state_id in ids_to_store.items()
                 ],
             )
 
         self._simple_insert_many_txn(
             txn,
@@ -145,6 +192,57 @@ class StateStore(SQLBaseStore):
             ],
         )
 
+    def _count_state_group_hops_txn(self, txn, state_group):
+        """Count how many `state_group_edges` hops lead back from a group.
+
+        The caller compares the result with MAX_STATE_DELTA_HOPS to decide
+        whether another delta may be chained onto the group.
+
+        Args:
+            txn: The database transaction to run the queries on.
+            state_group: ID of the group to start walking from.
+
+        Returns:
+            int: The number of `state_group_edges` rows followed from
+            `state_group` before reaching a group with no recorded predecessor.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = ("""
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """)
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                # The CTE is seeded with `state_group` itself, so it yields
+                # one row more than there are edges. Drop the seed so both
+                # engines report the same number for the same chain.
+                return row[0] - 1
+            else:
+                return 0
+        else:
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
         if event_type and state_key is not None:
@@ -214,26 +296,70 @@ class StateStore(SQLBaseStore):
             else:
                 where_clause = ""
 
-            sql = (
-                "SELECT state_group, event_id, type, state_key"
-                " FROM state_groups_state WHERE"
-                " state_group IN (%s) %s" % (
-                    ",".join("?" for _ in groups),
-                    where_clause,
-                )
-            )
-
-            args = list(groups)
-            if types is not None:
-                args.extend([i for typ in types for i in typ])
-
-            txn.execute(sql, args)
-            rows = self.cursor_to_dict(txn)
-
             results = {group: {} for group in groups}
-            for row in rows:
-                key = (row["type"], row["state_key"])
-                results[row["state_group"]][key] = row["event_id"]
+            if isinstance(self.database_engine, PostgresEngine):
+                sql = ("""
+                    WITH RECURSIVE state(state_group) AS (
+                        VALUES(?::bigint)
+                        UNION ALL
+                        SELECT prev_state_group FROM state_group_edges e, state s
+                        WHERE s.state_group = e.state_group
+                    )
+                    SELECT type, state_key, event_id FROM state_groups_state
+                    WHERE ROW(type, state_key, state_group) IN (
+                        SELECT type, state_key, max(state_group) FROM state
+                        INNER JOIN state_groups_state USING (state_group)
+                        GROUP BY type, state_key
+                    )
+                    %s;
+                """) % (where_clause,)
+
+                for group in groups:
+                    args = [group]
+                    if types is not None:
+                        args.extend([i for typ in types for i in typ])
+
+                    txn.execute(sql, args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
+            else:
+                for group in groups:
+                    group_tree = [group]
+                    next_group = group
+
+                    while next_group:
+                        next_group = self._simple_select_one_onecol_txn(
+                            txn,
+                            table="state_group_edges",
+                            keyvalues={"state_group": next_group},
+                            retcol="prev_state_group",
+                            allow_none=True,
+                        )
+                        if next_group:
+                            group_tree.append(next_group)
+
+                    sql = ("""
+                        SELECT type, state_key, event_id FROM state_groups_state
+                        INNER JOIN (
+                            SELECT type, state_key, max(state_group) as state_group
+                            FROM state_groups_state
+                            WHERE state_group IN (%s) %s
+                            GROUP BY type, state_key
+                        ) USING (type, state_key, state_group);
+                    """) % (",".join("?" for _ in group_tree), where_clause,)
+
+                    args = list(group_tree)
+                    if types is not None:
+                        args.extend([i for typ in types for i in typ])
+
+                    txn.execute(sql, args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
+
             return results
 
         results = {}
@@ -504,32 +630,5 @@ class StateStore(SQLBaseStore):
 
         defer.returnValue(results)
 
-    def get_all_new_state_groups(self, last_id, current_id, limit):
-        def get_all_new_state_groups_txn(txn):
-            sql = (
-                "SELECT id, room_id, event_id FROM state_groups"
-                " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            groups = txn.fetchall()
-
-            if not groups:
-                return ([], [])
-
-            lower_bound = groups[0][0]
-            upper_bound = groups[-1][0]
-            sql = (
-                "SELECT state_group, type, state_key, event_id"
-                " FROM state_groups_state"
-                " WHERE ? <= state_group AND state_group <= ?"
-            )
-
-            txn.execute(sql, (lower_bound, upper_bound))
-            state_group_state = txn.fetchall()
-            return (groups, state_group_state)
-        return self.runInteraction(
-            "get_all_new_state_groups", get_all_new_state_groups_txn
-        )
-
     def get_next_state_group(self):
         return self._state_groups_id_gen.get_next()