diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 73cebc7383..7f45c0cd99 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,6 +25,9 @@ import logging
logger = logging.getLogger(__name__)
+MAX_STATE_DELTA_HOPS = 100
+
+
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
@@ -104,7 +107,6 @@ class StateStore(SQLBaseStore):
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
- logger.info("Already persisted state_group: %r", context.state_group)
continue
state_event_ids = dict(context.current_state_ids)
@@ -120,29 +122,48 @@ class StateStore(SQLBaseStore):
)
if context.prev_group:
- self._simple_insert_txn(
- txn,
- table="state_group_edges",
- values={
- "state_group": context.state_group,
- "prev_state_group": context.prev_group,
- },
+ potential_hops = self._count_state_group_hops_txn(
+ txn, context.prev_group
)
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
+ if potential_hops < MAX_STATE_DELTA_HOPS:
+ self._simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={
"state_group": context.state_group,
- "room_id": event.room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in context.delta_ids.items()
- ],
- )
+ "prev_state_group": context.prev_group,
+ },
+ )
+
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": context.state_group,
+ "room_id": event.room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in context.delta_ids.items()
+ ],
+ )
+ else:
+ self._simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": context.state_group,
+ "room_id": event.room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in context.current_state_ids.items()
+ ],
+ )
else:
self._simple_insert_many_txn(
txn,
@@ -171,6 +192,41 @@ class StateStore(SQLBaseStore):
],
)
+ def _count_state_group_hops_txn(self, txn, state_group):
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = ("""
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """)
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
|