summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-09-02 10:41:38 +0100
committerErik Johnston <erik@matrix.org>2016-09-02 10:41:38 +0100
commit598317927cb8f741528d639f3ce875299fde478e (patch)
tree238e68353e38a5f5d87ee4dcddfadcea51241ee1 /synapse/storage
parentMove to storing state_groups_state as deltas (diff)
downloadsynapse-598317927cb8f741528d639f3ce875299fde478e.tar.xz
Limit the length of state chains
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events.py49
-rw-r--r--synapse/storage/state.py100
2 files changed, 106 insertions, 43 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 1a7d4c5199..7e9b351513 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -497,7 +497,11 @@ class EventsStore(SQLBaseStore):
 
                 # insert into the state_group, state_groups_state and
                 # event_to_state_groups tables.
-                self._store_mult_state_groups_txn(txn, ((event, context),))
+                try:
+                    self._store_mult_state_groups_txn(txn, ((event, context),))
+                except Exception:
+                    logger.exception("")
+                    raise
 
                 metadata_json = encode_json(
                     event.internal_metadata.get_dict()
@@ -1543,6 +1547,9 @@ class EventsStore(SQLBaseStore):
         )
         event_rows = txn.fetchall()
 
+        for event_id, state_key in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
         txn.execute(
@@ -1571,26 +1578,26 @@ class EventsStore(SQLBaseStore):
 
         # Get all state groups that are only referenced by events that are
         # to be deleted.
-        txn.execute(
-            "SELECT state_group FROM event_to_state_groups"
-            " INNER JOIN events USING (event_id)"
-            " WHERE state_group IN ("
-            "   SELECT DISTINCT state_group FROM events"
-            "   INNER JOIN event_to_state_groups USING (event_id)"
-            "   WHERE room_id = ? AND topological_ordering < ?"
-            " )"
-            " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
-            (room_id, topological_ordering, topological_ordering)
-        )
-        state_rows = txn.fetchall()
-        txn.executemany(
-            "DELETE FROM state_groups_state WHERE state_group = ?",
-            state_rows
-        )
-        txn.executemany(
-            "DELETE FROM state_groups WHERE id = ?",
-            state_rows
-        )
+        # txn.execute(
+        #     "SELECT state_group FROM event_to_state_groups"
+        #     " INNER JOIN events USING (event_id)"
+        #     " WHERE state_group IN ("
+        #     "   SELECT DISTINCT state_group FROM events"
+        #     "   INNER JOIN event_to_state_groups USING (event_id)"
+        #     "   WHERE room_id = ? AND topological_ordering < ?"
+        #     " )"
+        #     " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
+        #     (room_id, topological_ordering, topological_ordering)
+        # )
+        # state_rows = txn.fetchall()
+        # txn.executemany(
+        #     "DELETE FROM state_groups_state WHERE state_group = ?",
+        #     state_rows
+        # )
+        # txn.executemany(
+        #     "DELETE FROM state_groups WHERE id = ?",
+        #     state_rows
+        # )
         # Delete all non-state
         txn.executemany(
             "DELETE FROM event_to_state_groups WHERE event_id = ?",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 73cebc7383..7f45c0cd99 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,6 +25,9 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+MAX_STATE_DELTA_HOPS = 100
+
+
 class StateStore(SQLBaseStore):
     """ Keeps track of the state at a given event.
 
@@ -104,7 +107,6 @@ class StateStore(SQLBaseStore):
             state_groups[event.event_id] = context.state_group
 
             if self._have_persisted_state_group_txn(txn, context.state_group):
-                logger.info("Already persisted state_group: %r", context.state_group)
                 continue
 
             state_event_ids = dict(context.current_state_ids)
@@ -120,29 +122,48 @@ class StateStore(SQLBaseStore):
             )
 
             if context.prev_group:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={
-                        "state_group": context.state_group,
-                        "prev_state_group": context.prev_group,
-                    },
+                potential_hops = self._count_state_group_hops_txn(
+                    txn, context.prev_group
                 )
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
+                if potential_hops < MAX_STATE_DELTA_HOPS:
+                    self._simple_insert_txn(
+                        txn,
+                        table="state_group_edges",
+                        values={
                             "state_group": context.state_group,
-                            "room_id": event.room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in context.delta_ids.items()
-                    ],
-                )
+                            "prev_state_group": context.prev_group,
+                        },
+                    )
+
+                    self._simple_insert_many_txn(
+                        txn,
+                        table="state_groups_state",
+                        values=[
+                            {
+                                "state_group": context.state_group,
+                                "room_id": event.room_id,
+                                "type": key[0],
+                                "state_key": key[1],
+                                "event_id": state_id,
+                            }
+                            for key, state_id in context.delta_ids.items()
+                        ],
+                    )
+                else:
+                    self._simple_insert_many_txn(
+                        txn,
+                        table="state_groups_state",
+                        values=[
+                            {
+                                "state_group": context.state_group,
+                                "room_id": event.room_id,
+                                "type": key[0],
+                                "state_key": key[1],
+                                "event_id": state_id,
+                            }
+                            for key, state_id in context.current_state_ids.items()
+                        ],
+                    )
             else:
                 self._simple_insert_many_txn(
                     txn,
@@ -171,6 +192,41 @@ class StateStore(SQLBaseStore):
             ],
         )
 
+    def _count_state_group_hops_txn(self, txn, state_group):
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = ("""
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """)
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
         if event_type and state_key is not None: