summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/state.py100
1 files changed, 78 insertions, 22 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 73cebc7383..7f45c0cd99 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,6 +25,9 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+MAX_STATE_DELTA_HOPS = 100
+
+
 class StateStore(SQLBaseStore):
     """ Keeps track of the state at a given event.
 
@@ -104,7 +107,6 @@ class StateStore(SQLBaseStore):
             state_groups[event.event_id] = context.state_group
 
             if self._have_persisted_state_group_txn(txn, context.state_group):
-                logger.info("Already persisted state_group: %r", context.state_group)
                 continue
 
             state_event_ids = dict(context.current_state_ids)
@@ -120,29 +122,48 @@ class StateStore(SQLBaseStore):
             )
 
             if context.prev_group:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={
-                        "state_group": context.state_group,
-                        "prev_state_group": context.prev_group,
-                    },
+                potential_hops = self._count_state_group_hops_txn(
+                    txn, context.prev_group
                 )
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
+                if potential_hops < MAX_STATE_DELTA_HOPS:
+                    self._simple_insert_txn(
+                        txn,
+                        table="state_group_edges",
+                        values={
                             "state_group": context.state_group,
-                            "room_id": event.room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in context.delta_ids.items()
-                    ],
-                )
+                            "prev_state_group": context.prev_group,
+                        },
+                    )
+
+                    self._simple_insert_many_txn(
+                        txn,
+                        table="state_groups_state",
+                        values=[
+                            {
+                                "state_group": context.state_group,
+                                "room_id": event.room_id,
+                                "type": key[0],
+                                "state_key": key[1],
+                                "event_id": state_id,
+                            }
+                            for key, state_id in context.delta_ids.items()
+                        ],
+                    )
+                else:
+                    self._simple_insert_many_txn(
+                        txn,
+                        table="state_groups_state",
+                        values=[
+                            {
+                                "state_group": context.state_group,
+                                "room_id": event.room_id,
+                                "type": key[0],
+                                "state_key": key[1],
+                                "event_id": state_id,
+                            }
+                            for key, state_id in context.current_state_ids.items()
+                        ],
+                    )
             else:
                 self._simple_insert_many_txn(
                     txn,
@@ -171,6 +192,41 @@ class StateStore(SQLBaseStore):
             ],
         )
 
+    def _count_state_group_hops_txn(self, txn, state_group):
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = ("""
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """)
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
         if event_type and state_key is not None: