summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py161
1 files changed, 102 insertions, 59 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ec551b0b4f..73cebc7383 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -16,6 +16,7 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches import intern_string
+from synapse.storage.engines import PostgresEngine
 
 from twisted.internet import defer
 
@@ -118,20 +119,45 @@ class StateStore(SQLBaseStore):
                 },
             )
 
-            self._simple_insert_many_txn(
-                txn,
-                table="state_groups_state",
-                values=[
-                    {
+            if context.prev_group:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_group_edges",
+                    values={
                         "state_group": context.state_group,
-                        "room_id": event.room_id,
-                        "type": key[0],
-                        "state_key": key[1],
-                        "event_id": state_id,
-                    }
-                    for key, state_id in state_event_ids.items()
-                ],
-            )
+                        "prev_state_group": context.prev_group,
+                    },
+                )
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": context.state_group,
+                            "room_id": event.room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in context.delta_ids.items()
+                    ],
+                )
+            else:
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": context.state_group,
+                            "room_id": event.room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in state_event_ids.items()
+                    ],
+                )
 
         self._simple_insert_many_txn(
             txn,
@@ -214,26 +240,70 @@ class StateStore(SQLBaseStore):
             else:
                 where_clause = ""
 
-            sql = (
-                "SELECT state_group, event_id, type, state_key"
-                " FROM state_groups_state WHERE"
-                " state_group IN (%s) %s" % (
-                    ",".join("?" for _ in groups),
-                    where_clause,
-                )
-            )
-
-            args = list(groups)
-            if types is not None:
-                args.extend([i for typ in types for i in typ])
-
-            txn.execute(sql, args)
-            rows = self.cursor_to_dict(txn)
-
             results = {group: {} for group in groups}
-            for row in rows:
-                key = (row["type"], row["state_key"])
-                results[row["state_group"]][key] = row["event_id"]
+            if isinstance(self.database_engine, PostgresEngine):
+                sql = ("""
+                    WITH RECURSIVE state(state_group) AS (
+                        VALUES(?::bigint)
+                        UNION ALL
+                        SELECT prev_state_group FROM state_group_edges e, state s
+                        WHERE s.state_group = e.state_group
+                    )
+                    SELECT type, state_key, event_id FROM state_groups_state
+                    WHERE ROW(type, state_key, state_group) IN (
+                        SELECT type, state_key, max(state_group) FROM state
+                        INNER JOIN state_groups_state USING (state_group)
+                        GROUP BY type, state_key
+                    )
+                    %s;
+                """) % (where_clause,)
+
+                for group in groups:
+                    args = [group]
+                    if types is not None:
+                        args.extend([i for typ in types for i in typ])
+
+                    txn.execute(sql, args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
+            else:
+                for group in groups:
+                    group_tree = [group]
+                    next_group = group
+
+                    while next_group:
+                        next_group = self._simple_select_one_onecol_txn(
+                            txn,
+                            table="state_group_edges",
+                            keyvalues={"state_group": next_group},
+                            retcol="prev_state_group",
+                            allow_none=True,
+                        )
+                        if next_group:
+                            group_tree.append(next_group)
+
+                    sql = ("""
+                        SELECT type, state_key, event_id FROM state_groups_state
+                        INNER JOIN (
+                            SELECT type, state_key, max(state_group) as state_group
+                            FROM state_groups_state
+                            WHERE state_group IN (%s) %s
+                            GROUP BY type, state_key
+                        ) USING (type, state_key, state_group);
+                    """) % (",".join("?" for _ in group_tree), where_clause,)
+
+                    args = list(group_tree)
+                    if types is not None:
+                        args.extend([i for typ in types for i in typ])
+
+                    txn.execute(sql, args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
+
             return results
 
         results = {}
@@ -504,32 +574,5 @@ class StateStore(SQLBaseStore):
 
         defer.returnValue(results)
 
-    def get_all_new_state_groups(self, last_id, current_id, limit):
-        def get_all_new_state_groups_txn(txn):
-            sql = (
-                "SELECT id, room_id, event_id FROM state_groups"
-                " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            groups = txn.fetchall()
-
-            if not groups:
-                return ([], [])
-
-            lower_bound = groups[0][0]
-            upper_bound = groups[-1][0]
-            sql = (
-                "SELECT state_group, type, state_key, event_id"
-                " FROM state_groups_state"
-                " WHERE ? <= state_group AND state_group <= ?"
-            )
-
-            txn.execute(sql, (lower_bound, upper_bound))
-            state_group_state = txn.fetchall()
-            return (groups, state_group_state)
-        return self.runInteraction(
-            "get_all_new_state_groups", get_all_new_state_groups_txn
-        )
-
     def get_next_state_group(self):
         return self._state_groups_id_gen.get_next()