summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/snapshot.py16
-rw-r--r--synapse/state.py34
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/delta/35/state.sql21
-rw-r--r--synapse/storage/state.py161
5 files changed, 172 insertions, 62 deletions
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e895b1c450..ec32008d5a 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,9 +15,25 @@
 
 
 class EventContext(object):
+    __slots__ = [
+        "current_state_ids",
+        "prev_state_ids",
+        "state_group",
+        "rejected",
+        "push_actions",
+        "prev_group",
+        "delta_ids",
+        "prev_state_events",
+    ]
+
     def __init__(self):
         self.current_state_ids = None
         self.prev_state_ids = None
         self.state_group = None
         self.rejected = False
         self.push_actions = []
+
+        self.prev_group = None
+        self.delta_ids = None
+
+        self.prev_state_events = None
diff --git a/synapse/state.py b/synapse/state.py
index b31bbcdbd2..cd428e83cd 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -54,12 +54,15 @@ def _gen_state_id():
 
 
 class _StateCacheEntry(object):
-    __slots__ = ["state", "state_group", "state_id"]
+    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
-    def __init__(self, state, state_group):
+    def __init__(self, state, state_group, prev_group=None, delta_ids=None):
         self.state = state
         self.state_group = state_group
 
+        self.prev_group = prev_group
+        self.delta_ids = delta_ids
+
         # The `state_id` is a unique ID we generate that can be used as ID for
         # this collection of state. Usually this would be the same as the
         # state group, but on worker instances we can't generate a new state
@@ -243,11 +246,20 @@ class StateHandler(object):
             if key in context.prev_state_ids:
                 replaces = context.prev_state_ids[key]
                 event.unsigned["replaces_state"] = replaces
+
             context.current_state_ids = dict(context.prev_state_ids)
             context.current_state_ids[key] = event.event_id
+
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+            if context.delta_ids is not None:
+                context.delta_ids[key] = event.event_id
         else:
             context.current_state_ids = context.prev_state_ids
 
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
         context.prev_state_events = []
         defer.returnValue(context)
 
@@ -281,6 +293,8 @@ class StateHandler(object):
             defer.returnValue(_StateCacheEntry(
                 state=state_list,
                 state_group=name,
+                prev_group=name,
+                delta_ids={},
             ))
 
         if self._state_cache is not None:
@@ -330,6 +344,7 @@ class StateHandler(object):
             if new_state_event_ids == frozenset(e_id for e_id in events):
                 state_group = sg
                 break
+
         if state_group is None:
             # Worker instances don't have access to this method, but we want
             # to set the state_group on the main instance to increase cache
@@ -337,9 +352,24 @@ class StateHandler(object):
             if hasattr(self.store, "get_next_state_group"):
                 state_group = self.store.get_next_state_group()
 
+        prev_group = None
+        delta_ids = None
+        for old_group, old_ids in state_groups_ids.items():
+            if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
+                n_delta_ids = {
+                    k: v
+                    for k, v in new_state.items()
+                    if old_ids.get(k) != v
+                }
+                if not delta_ids or len(n_delta_ids) < len(delta_ids):
+                    prev_group = old_group
+                    delta_ids = n_delta_ids
+
         cache = _StateCacheEntry(
             state=new_state,
             state_group=state_group,
+            prev_group=prev_group,
+            delta_ids=delta_ids,
         )
 
         if self._state_cache is not None:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b94ce7bea1..b1fbc4ffa5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 34
+SCHEMA_VERSION = 35
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/schema/delta/35/state.sql
new file mode 100644
index 0000000000..c4c244c169
--- /dev/null
+++ b/synapse/storage/schema/delta/35/state.sql
@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_group_edges(
+    state_group BIGINT NOT NULL,
+    prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ec551b0b4f..73cebc7383 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -16,6 +16,7 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches import intern_string
+from synapse.storage.engines import PostgresEngine
 
 from twisted.internet import defer
 
@@ -118,20 +119,45 @@ class StateStore(SQLBaseStore):
                 },
             )
 
-            self._simple_insert_many_txn(
-                txn,
-                table="state_groups_state",
-                values=[
-                    {
+            if context.prev_group:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_group_edges",
+                    values={
                         "state_group": context.state_group,
-                        "room_id": event.room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                        "event_id": state_id,
-                    }
-                    for key, state_id in state_event_ids.items()
-                ],
-            )
+                        "prev_state_group": context.prev_group,
+                    },
+                )
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": context.state_group,
+                            "room_id": event.room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in context.delta_ids.items()
+                    ],
+                )
+            else:
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": context.state_group,
+                            "room_id": event.room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in state_event_ids.items()
+                    ],
+                )
 
         self._simple_insert_many_txn(
             txn,
@@ -214,26 +240,70 @@ class StateStore(SQLBaseStore):
             else:
                 where_clause = ""
 
-            sql = (
-                "SELECT state_group, event_id, type, state_key"
-                " FROM state_groups_state WHERE"
-                " state_group IN (%s) %s" % (
-                    ",".join("?" for _ in groups),
-                    where_clause,
-                )
-            )
-
-            args = list(groups)
-            if types is not None:
-                args.extend([i for typ in types for i in typ])
-
-            txn.execute(sql, args)
-            rows = self.cursor_to_dict(txn)
-
             results = {group: {} for group in groups}
-            for row in rows:
-                key = (row["type"], row["state_key"])
-                results[row["state_group"]][key] = row["event_id"]
+            if isinstance(self.database_engine, PostgresEngine):
+                sql = ("""
+                    WITH RECURSIVE state(state_group) AS (
+                        VALUES(?::bigint)
+                        UNION ALL
+                        SELECT prev_state_group FROM state_group_edges e, state s
+                        WHERE s.state_group = e.state_group
+                    )
+                    SELECT type, state_key, event_id FROM state_groups_state
+                    WHERE ROW(type, state_key, state_group) IN (
+                        SELECT type, state_key, max(state_group) FROM state
+                        INNER JOIN state_groups_state USING (state_group)
+                        GROUP BY type, state_key
+                    )
+                    %s;
+                """) % (where_clause,)
+
+                for group in groups:
+                    args = [group]
+                    if types is not None:
+                        args.extend([i for typ in types for i in typ])
+
+                    txn.execute(sql, args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
+            else:
+                for group in groups:
+                    group_tree = [group]
+                    next_group = group
+
+                    while next_group:
+                        next_group = self._simple_select_one_onecol_txn(
+                            txn,
+                            table="state_group_edges",
+                            keyvalues={"state_group": next_group},
+                            retcol="prev_state_group",
+                            allow_none=True,
+                        )
+                        if next_group:
+                            group_tree.append(next_group)
+
+                    sql = ("""
+                        SELECT type, state_key, event_id FROM state_groups_state
+                        INNER JOIN (
+                            SELECT type, state_key, max(state_group) as state_group
+                            FROM state_groups_state
+                            WHERE state_group IN (%s) %s
+                            GROUP BY type, state_key
+                        ) USING (type, state_key, state_group);
+                    """) % (",".join("?" for _ in group_tree), where_clause,)
+
+                    args = list(group_tree)
+                    if types is not None:
+                        args.extend([i for typ in types for i in typ])
+
+                    txn.execute(sql, args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
+
             return results
 
         results = {}
@@ -504,32 +574,5 @@ class StateStore(SQLBaseStore):
 
         defer.returnValue(results)
 
-    def get_all_new_state_groups(self, last_id, current_id, limit):
-        def get_all_new_state_groups_txn(txn):
-            sql = (
-                "SELECT id, room_id, event_id FROM state_groups"
-                " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            groups = txn.fetchall()
-
-            if not groups:
-                return ([], [])
-
-            lower_bound = groups[0][0]
-            upper_bound = groups[-1][0]
-            sql = (
-                "SELECT state_group, type, state_key, event_id"
-                " FROM state_groups_state"
-                " WHERE ? <= state_group AND state_group <= ?"
-            )
-
-            txn.execute(sql, (lower_bound, upper_bound))
-            state_group_state = txn.fetchall()
-            return (groups, state_group_state)
-        return self.runInteraction(
-            "get_all_new_state_groups", get_all_new_state_groups_txn
-        )
-
     def get_next_state_group(self):
         return self._state_groups_id_gen.get_next()