diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7128dc1742..e46e44ba54 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -16,6 +16,8 @@
import logging
from typing import Dict, List, Optional, Tuple
+import attr
+
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
@@ -28,6 +30,25 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+ """Return value for _calculate_chain_cover_txn.
+ """
+
+ # The last room_id/depth/stream processed.
+ room_id = attr.ib(type=str)
+ depth = attr.ib(type=int)
+ stream = attr.ib(type=int)
+
+ # Number of rows processed
+ processed_count = attr.ib(type=int)
+
+ # Map from room_id to last depth/stream processed for each room that we have
+ # processed all events for (i.e. the rooms we can flip the
+ # `has_auth_chain_index` for)
+ finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
current_room_id = progress.get("current_room_id", "")
- # Have we finished processing the current room.
- finished = progress.get("finished", True)
-
# Where we've processed up to in the room, defaults to the start of the
# room.
last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1)
- # Have we set the `has_auth_chain_index` for the room yet.
- has_set_room_has_chain_index = progress.get(
- "has_set_room_has_chain_index", False
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ current_room_id,
+ last_depth,
+ last_stream,
+ batch_size,
+ single_room=False,
)
- if finished:
- # If we've finished with the previous room (or its our first
- # iteration) we move on to the next room.
-
- def _get_next_room(txn: Cursor) -> Optional[str]:
- sql = """
- SELECT room_id FROM rooms
- WHERE room_id > ?
- AND (
- NOT has_auth_chain_index
- OR has_auth_chain_index IS NULL
- )
- ORDER BY room_id
- LIMIT 1
- """
- txn.execute(sql, (current_room_id,))
- row = txn.fetchone()
- if row:
- return row[0]
+ finished = result.processed_count == 0
- return None
-
- current_room_id = await self.db_pool.runInteraction(
- "_chain_cover_index", _get_next_room
- )
- if not current_room_id:
- await self.db_pool.updates._end_background_update("chain_cover")
- return 0
-
- logger.debug("Adding chain cover to %s", current_room_id)
-
- def _calculate_auth_chain(
- txn: Cursor, last_depth: int, last_stream: int
- ) -> Tuple[int, int, int]:
- # Get the next set of events in the room (that we haven't already
- # computed chain cover for). We do this in topological order.
-
- # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
- # comparison, but that is not supported on older SQLite versions
- tuple_clause, tuple_args = make_tuple_comparison_clause(
- self.database_engine,
- [
- ("topological_ordering", last_depth),
- ("stream_ordering", last_stream),
- ],
- )
+ total_rows_processed = result.processed_count
+ current_room_id = result.room_id
+ last_depth = result.depth
+ last_stream = result.stream
- sql = """
- SELECT
- event_id, state_events.type, state_events.state_key,
- topological_ordering, stream_ordering
- FROM events
- INNER JOIN state_events USING (event_id)
- LEFT JOIN event_auth_chains USING (event_id)
- LEFT JOIN event_auth_chain_to_calculate USING (event_id)
- WHERE events.room_id = ?
- AND event_auth_chains.event_id IS NULL
- AND event_auth_chain_to_calculate.event_id IS NULL
- AND %(tuple_cmp)s
- ORDER BY topological_ordering, stream_ordering
- LIMIT ?
- """ % {
- "tuple_cmp": tuple_clause,
- }
-
- args = [current_room_id]
- args.extend(tuple_args)
- args.append(batch_size)
-
- txn.execute(sql, args)
- rows = txn.fetchall()
-
- # Put the results in the necessary format for
- # `_add_chain_cover_index`
- event_to_room_id = {row[0]: current_room_id for row in rows}
- event_to_types = {row[0]: (row[1], row[2]) for row in rows}
-
- new_last_depth = rows[-1][3] if rows else last_depth # type: int
- new_last_stream = rows[-1][4] if rows else last_stream # type: int
-
- count = len(rows)
-
- # We also need to fetch the auth events for them.
- auth_events = self.db_pool.simple_select_many_txn(
- txn,
- table="event_auth",
- column="event_id",
- iterable=event_to_room_id,
- keyvalues={},
- retcols=("event_id", "auth_id"),
- )
-
- event_to_auth_chain = {} # type: Dict[str, List[str]]
- for row in auth_events:
- event_to_auth_chain.setdefault(row["event_id"], []).append(
- row["auth_id"]
- )
-
- # Calculate and persist the chain cover index for this set of events.
- #
- # Annoyingly we need to gut wrench into the persit event store so that
- # we can reuse the function to calculate the chain cover for rooms.
- PersistEventsStore._add_chain_cover_index(
- txn,
- self.db_pool,
- event_to_room_id,
- event_to_types,
- event_to_auth_chain,
- )
-
- return new_last_depth, new_last_stream, count
-
- last_depth, last_stream, count = await self.db_pool.runInteraction(
- "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
- )
-
- total_rows_processed = count
-
- if count < batch_size and not has_set_room_has_chain_index:
+ for room_id, (depth, stream) in result.finished_room_map.items():
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that its possible for
# further events to be persisted between the above and setting the
@@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.simple_update(
table="rooms",
- keyvalues={"room_id": current_room_id},
+ keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
- has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
- last_depth, last_stream, count = await self.db_pool.runInteraction(
- "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ room_id,
+ depth,
+ stream,
+ batch_size=None,
+ single_room=True,
)
- total_rows_processed += count
+ total_rows_processed += result.processed_count
- # Note that at this point its technically possible that more events
- # than our `batch_size` have been persisted without their chain
- # cover, so we need to continue processing this room if the last
- # count returned was equal to the `batch_size`.
+ if finished:
+ await self.db_pool.updates._end_background_update("chain_cover")
+ return total_rows_processed
- if count < batch_size:
- # We've finished calculating the index for this room, move on to the
- # next room.
- await self.db_pool.updates._background_update_progress(
- "chain_cover", {"current_room_id": current_room_id, "finished": True},
- )
- else:
- # We still have outstanding events to calculate the index for.
- await self.db_pool.updates._background_update_progress(
- "chain_cover",
- {
- "current_room_id": current_room_id,
- "last_depth": last_depth,
- "last_stream": last_stream,
- "has_auth_chain_index": has_set_room_has_chain_index,
- "finished": False,
- },
- )
+ await self.db_pool.updates._background_update_progress(
+ "chain_cover",
+ {
+ "current_room_id": current_room_id,
+ "last_depth": last_depth,
+ "last_stream": last_stream,
+ },
+ )
return total_rows_processed
+
+ def _calculate_chain_cover_txn(
+ self,
+ txn: Cursor,
+ last_room_id: str,
+ last_depth: int,
+ last_stream: int,
+ batch_size: Optional[int],
+ single_room: bool,
+ ) -> _CalculateChainCover:
+ """Calculate the chain cover for `batch_size` events, ordered by
+ `(room_id, depth, stream)`.
+
+ Args:
+ txn,
+ last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+ tuple to fetch results after.
+ batch_size: The maximum number of events to process. If None then
+ no limit.
+ single_room: Whether to calculate the index for just the given
+ room.
+ """
+
+ # Get the next set of events in the room (that we haven't already
+ # computed chain cover for). We do this in topological order.
+
+ # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+ # comparison, but that is not supported on older SQLite versions
+ tuple_clause, tuple_args = make_tuple_comparison_clause(
+ self.database_engine,
+ [
+ ("events.room_id", last_room_id),
+ ("topological_ordering", last_depth),
+ ("stream_ordering", last_stream),
+ ],
+ )
+
+ extra_clause = ""
+ if single_room:
+ extra_clause = "AND events.room_id = ?"
+ tuple_args.append(last_room_id)
+
+ sql = """
+ SELECT
+ event_id, state_events.type, state_events.state_key,
+ topological_ordering, stream_ordering,
+ events.room_id
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+ WHERE event_auth_chains.event_id IS NULL
+ AND event_auth_chain_to_calculate.event_id IS NULL
+ AND %(tuple_cmp)s
+ %(extra)s
+ ORDER BY events.room_id, topological_ordering, stream_ordering
+ %(limit)s
+ """ % {
+ "tuple_cmp": tuple_clause,
+ "limit": "LIMIT ?" if batch_size is not None else "",
+ "extra": extra_clause,
+ }
+
+ if batch_size is not None:
+ tuple_args.append(batch_size)
+
+ txn.execute(sql, tuple_args)
+ rows = txn.fetchall()
+
+ # Put the results in the necessary format for
+ # `_add_chain_cover_index`
+ event_to_room_id = {row[0]: row[5] for row in rows}
+ event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+ # Calculate the new last position we've processed up to.
+ new_last_depth = rows[-1][3] if rows else last_depth # type: int
+ new_last_stream = rows[-1][4] if rows else last_stream # type: int
+ new_last_room_id = rows[-1][5] if rows else "" # type: str
+
+ # Map from room_id to last depth/stream_ordering processed for the room,
+ # excluding the last room (which we're likely still processing). We also
+ # need to include the room passed in if it's not included in the result
+ # set (as we then know we've processed all events in said room).
+ #
+ # This is the set of rooms that we can now safely flip the
+ # `has_auth_chain_index` bit for.
+ finished_rooms = {
+ row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+ }
+ if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+ finished_rooms[last_room_id] = (last_depth, last_stream)
+
+ count = len(rows)
+
+ # We also need to fetch the auth events for them.
+ auth_events = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ )
+
+ event_to_auth_chain = {} # type: Dict[str, List[str]]
+ for row in auth_events:
+ event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+ # Calculate and persist the chain cover index for this set of events.
+ #
+ # Annoyingly we need to gut wrench into the persit event store so that
+ # we can reuse the function to calculate the chain cover for rooms.
+ PersistEventsStore._add_chain_cover_index(
+ txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ )
+
+ return _CalculateChainCover(
+ room_id=new_last_room_id,
+ depth=new_last_depth,
+ stream=new_last_stream,
+ processed_count=count,
+ finished_room_map=finished_rooms,
+ )
|