diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 186f064036..3216b3f3c8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -43,7 +43,6 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -100,14 +99,6 @@ class PersistEventsStore:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
- def get_chain_id_txn(txn):
- txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
-
- self._event_chain_id_gen = build_sequence_generator(
- db.engine, get_chain_id_txn, "event_auth_chain_id"
- )
-
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@@ -466,9 +457,6 @@ class PersistEventsStore:
if not state_events:
return
- # Map from event ID to chain ID/sequence number.
- chain_map = {} # type: Dict[str, Tuple[int, int]]
-
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
@@ -479,19 +467,44 @@ class PersistEventsStore:
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
+ event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+ self._add_chain_cover_index(
+ txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ )
+
+ @staticmethod
+ def _add_chain_cover_index(
+ txn,
+ db_pool: DatabasePool,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ ) -> None:
+ """Calculate the chain cover index for the given events.
+
+ Args:
+ event_to_room_id: Event ID to the room ID of the event
+ event_to_types: Event ID to type and state_key of the event
+ event_to_auth_chain: Event ID to list of auth event IDs of the
+ event (events with no auth events can be excluded).
+ """
+
+ # Map from event ID to chain ID/sequence number.
+ chain_map = {} # type: Dict[str, Tuple[int, int]]
# Set of event IDs to calculate chain ID/seq numbers for.
- events_to_calc_chain_id_for = set(state_events)
+ events_to_calc_chain_id_for = set(event_to_room_id)
# We check if there are any events that need to be handled in the rooms
# we're looking at. These should just be out of band memberships, where
# we didn't have the auth chain when we first persisted.
- rows = self.db_pool.simple_select_many_txn(
+ rows = db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="room_id",
- iterable={e.room_id for e in state_events.values()},
+ iterable=set(event_to_room_id.values()),
retcols=("event_id", "type", "state_key"),
)
for row in rows:
@@ -502,7 +515,7 @@ class PersistEventsStore:
# (We could pull out the auth events for all rows at once using
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
- auth_events = self.db_pool.simple_select_onecol_txn(
+ auth_events = db_pool.simple_select_onecol_txn(
txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
)
@@ -551,9 +564,7 @@ class PersistEventsStore:
events_to_calc_chain_id_for.add(auth_id)
- event_to_auth_chain[
- auth_id
- ] = self.db_pool.simple_select_onecol_txn(
+ event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
txn,
"event_auth",
keyvalues={"event_id": auth_id},
@@ -582,16 +593,17 @@ class PersistEventsStore:
# the list of events to calculate chain IDs for next time
# around. (Otherwise we will have already added it to the
# table).
- event = state_events.get(event_id)
- if event:
- self.db_pool.simple_insert_txn(
+ room_id = event_to_room_id.get(event_id)
+ if room_id:
+ e_type, state_key = event_to_types[event_id]
+ db_pool.simple_insert_txn(
txn,
table="event_auth_chain_to_calculate",
values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
+ "event_id": event_id,
+ "room_id": room_id,
+ "type": e_type,
+ "state_key": state_key,
},
)
@@ -617,7 +629,7 @@ class PersistEventsStore:
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
- for auth_id in event_to_auth_chain[event_id]:
+ for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
@@ -629,7 +641,7 @@ class PersistEventsStore:
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
- already_allocated = self.db_pool.simple_select_one_onecol_txn(
+ already_allocated = db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
@@ -650,14 +662,14 @@ class PersistEventsStore:
)
if not new_chain_tuple:
- new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+ new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
- self.db_pool.simple_insert_many_txn(
+ db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
values=[
@@ -666,7 +678,7 @@ class PersistEventsStore:
],
)
- self.db_pool.simple_delete_many_txn(
+ db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
@@ -699,7 +711,7 @@ class PersistEventsStore:
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
- rows = self.db_pool.simple_select_many_txn(
+ rows = db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
@@ -730,11 +742,11 @@ class PersistEventsStore:
# auth events (A, B) to check if B is reachable from A.
reduction = {
a_id
- for a_id in event_to_auth_chain[event_id]
+ for a_id in event_to_auth_chain.get(event_id, [])
if chain_map[a_id][0] != chain_id
}
for start_auth_id, end_auth_id in itertools.permutations(
- event_to_auth_chain[event_id], r=2,
+ event_to_auth_chain.get(event_id, []), r=2,
):
if chain_links.exists_path_from(
chain_map[start_auth_id], chain_map[end_auth_id]
@@ -763,7 +775,7 @@ class PersistEventsStore:
(chain_id, sequence_number), (target_id, target_seq)
)
- self.db_pool.simple_insert_many_txn(
+ db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
values=[
|