diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 519059c306..e8133de2fa 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -37,25 +37,55 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively.
"""
+ EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
def __init__(self, hs):
super(EventFederationStore, self).__init__(hs)
+ self.register_background_update_handler(
+ self.EVENT_AUTH_STATE_ONLY,
+ self._background_delete_non_state_event_auth,
+ )
+
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
- def get_auth_chain(self, event_ids):
- return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
+ def get_auth_chain(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
+
+ Returns:
+ list of events
+ """
+ return self.get_auth_chain_ids(
+ event_ids, include_given=include_given,
+ ).addCallback(self._get_events)
+
+ def get_auth_chain_ids(self, event_ids, include_given=False):
+ """Get auth events for given event_ids. The events *must* be state events.
+
+ Args:
+ event_ids (list): state events
+ include_given (bool): include the given events in result
- def get_auth_chain_ids(self, event_ids):
+ Returns:
+ list of event_ids
+ """
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
- event_ids
+ event_ids, include_given
)
- def _get_auth_chain_ids_txn(self, txn, event_ids):
- results = set()
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ if include_given:
+ results = set(event_ids)
+ else:
+ results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
@@ -504,3 +534,52 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+
+ @defer.inlineCallbacks
+ def _background_delete_non_state_event_auth(self, progress, batch_size):
+ def delete_event_auth(txn):
+ target_min_stream_id = progress.get("target_min_stream_id_inclusive")
+ max_stream_id = progress.get("max_stream_id_exclusive")
+
+ if not target_min_stream_id or not max_stream_id:
+ txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ target_min_stream_id = rows[0][0]
+
+ txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
+ rows = txn.fetchall()
+ max_stream_id = rows[0][0]
+
+ min_stream_id = max_stream_id - batch_size
+
+ sql = """
+ DELETE FROM event_auth
+ WHERE event_id IN (
+ SELECT event_id FROM events
+ LEFT JOIN state_events USING (room_id, event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND state_key IS null
+ )
+ """
+
+ txn.execute(sql, (min_stream_id, max_stream_id,))
+
+ new_progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_AUTH_STATE_ONLY, new_progress
+ )
+
+ return min_stream_id >= target_min_stream_id
+
+ result = yield self.runInteraction(
+ self.EVENT_AUTH_STATE_ONLY, delete_event_auth
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+
+ defer.returnValue(batch_size)
|