diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index eeca85fc94..6e8aeed7b4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -67,6 +67,8 @@ class _BackgroundUpdates:
EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
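+    # back-populate the `state_key` and `rejection_reason` columns on `events`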
+ EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
@@ -253,6 +255,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
replaces_index="ev_edges_id",
)
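+        # handler to back-populate the new `state_key` and `rejection_reason`
+        # columns on the `events` table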
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+ self._background_events_populate_state_key_rejections,
+ )
+
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1399,3 +1406,83 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
return batch_size
+
+ async def _background_events_populate_state_key_rejections(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Back-populate `events.state_key` and `events.rejection_reason"""
+
+ min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"]
+ max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"]
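+        # these bounds are fixed when the update is first scheduled; events
+        # persisted after that point should have the new columns populated at
+        # insertion time, so we only need to back-fill up to
+        # `max_stream_ordering_inclusive`.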
+
+ def _populate_txn(txn: LoggingTransaction) -> bool:
+ """Returns True if we're done."""
+
+            # first we need to find an endpoint: the stream ordering of the
+            # final row in a batch of `batch_size` rows. To do that, we skip
+            # over the first (batch_size-1) rows and take the next one.
+ txn.execute(
+ """
+ SELECT stream_ordering FROM events
+ WHERE stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering
+ LIMIT 1 OFFSET ?
+ """,
+ (
+ min_stream_ordering_exclusive,
+ max_stream_ordering_inclusive,
+ batch_size - 1,
+ ),
+ )
+
+ endpoint = None
+ row = txn.fetchone()
+ if row:
+ endpoint = row[0]
+
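+            # if we found an endpoint, this batch covers everything up to and
+            # including it; if not, there were fewer than `batch_size` rows left,
+            # so we leave the upper bound off. (Re-writing any rows beyond
+            # `max_stream_ordering_inclusive` is harmless, since the UPDATE below
+            # just recomputes the same values.)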
+ where_clause = "stream_ordering > ?"
+ args = [min_stream_ordering_exclusive]
+            if endpoint is not None:
+ where_clause += " AND stream_ordering <= ?"
+ args.append(endpoint)
+
+ # now do the updates.
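+            # (a correlated subquery which matches no rows evaluates to NULL, so
+            # events which are not state events, or which were never rejected,
+            # get NULL in the corresponding column.)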
+ txn.execute(
+ f"""
+ UPDATE events
+ SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id),
+ rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id)
+ WHERE ({where_clause})
+ """,
+ args,
+ )
+
+ logger.info(
+ "populated new `events` columns up to %s/%i: updated %i rows",
+ endpoint,
+ max_stream_ordering_inclusive,
+ txn.rowcount,
+ )
+
+ if endpoint is None:
+ # we're done
+ return True
+
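+            # otherwise, record how far we got, and ask to be run again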
+ progress["min_stream_ordering_exclusive"] = endpoint
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+ progress,
+ )
+ return False
+
+ done = await self.db_pool.runInteraction(
+ desc="events_populate_state_key_rejections", func=_populate_txn
+ )
+
+ if done:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS
+ )
+
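+        # report the (nominal) number of rows processed, so that the background
+        # updater can tune the batch size for the next iteration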
+ return batch_size