summary refs log tree commit diff
path: root/synapse/storage/databases/main/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events.py')
-rw-r--r--synapse/storage/databases/main/events.py82
1 files changed, 47 insertions, 35 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py

index 186f064036..3216b3f3c8 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -43,7 +43,6 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically @@ -100,14 +99,6 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() - def get_chain_id_txn(txn): - txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return txn.fetchone()[0] - - self._event_chain_id_gen = build_sequence_generator( - db.engine, get_chain_id_txn, "event_auth_chain_id" - ) - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -466,9 +457,6 @@ class PersistEventsStore: if not state_events: return - # Map from event ID to chain ID/sequence number. - chain_map = {} # type: Dict[str, Tuple[int, int]] - # We need to know the type/state_key and auth events of the events we're # calculating chain IDs for. We don't rely on having the full Event # instances as we'll potentially be pulling more events from the DB and @@ -479,19 +467,44 @@ class PersistEventsStore: event_to_auth_chain = { e.event_id: e.auth_event_ids() for e in state_events.values() } + event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} + + self._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + @staticmethod + def _add_chain_cover_index( + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + ) -> None: + """Calculate the chain cover index for the given events. + + Args: + event_to_room_id: Event ID to the room ID of the event + event_to_types: Event ID to type and state_key of the event + event_to_auth_chain: Event ID to list of auth event IDs of the + event (events with no auth events can be excluded). + """ + + # Map from event ID to chain ID/sequence number. + chain_map = {} # type: Dict[str, Tuple[int, int]] # Set of event IDs to calculate chain ID/seq numbers for. - events_to_calc_chain_id_for = set(state_events) + events_to_calc_chain_id_for = set(event_to_room_id) # We check if there are any events that need to be handled in the rooms # we're looking at. These should just be out of band memberships, where # we didn't have the auth chain when we first persisted. - rows = self.db_pool.simple_select_many_txn( + rows = db_pool.simple_select_many_txn( txn, table="event_auth_chain_to_calculate", keyvalues={}, column="room_id", - iterable={e.room_id for e in state_events.values()}, + iterable=set(event_to_room_id.values()), retcols=("event_id", "type", "state_key"), ) for row in rows: @@ -502,7 +515,7 @@ class PersistEventsStore: # (We could pull out the auth events for all rows at once using # simple_select_many, but this case happens rarely and almost always # with a single row.) - auth_events = self.db_pool.simple_select_onecol_txn( + auth_events = db_pool.simple_select_onecol_txn( txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", ) @@ -551,9 +564,7 @@ class PersistEventsStore: events_to_calc_chain_id_for.add(auth_id) - event_to_auth_chain[ - auth_id - ] = self.db_pool.simple_select_onecol_txn( + event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn( txn, "event_auth", keyvalues={"event_id": auth_id}, @@ -582,16 +593,17 @@ class PersistEventsStore: # the list of events to calculate chain IDs for next time # around. (Otherwise we will have already added it to the # table). - event = state_events.get(event_id) - if event: - self.db_pool.simple_insert_txn( + room_id = event_to_room_id.get(event_id) + if room_id: + e_type, state_key = event_to_types[event_id] + db_pool.simple_insert_txn( txn, table="event_auth_chain_to_calculate", values={ - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, + "event_id": event_id, + "room_id": room_id, + "type": e_type, + "state_key": state_key, }, ) @@ -617,7 +629,7 @@ class PersistEventsStore: events_to_calc_chain_id_for, event_to_auth_chain ): existing_chain_id = None - for auth_id in event_to_auth_chain[event_id]: + for auth_id in event_to_auth_chain.get(event_id, []): if event_to_types.get(event_id) == event_to_types.get(auth_id): existing_chain_id = chain_map[auth_id] break @@ -629,7 +641,7 @@ class PersistEventsStore: proposed_new_id = existing_chain_id[0] proposed_new_seq = existing_chain_id[1] + 1 if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: - already_allocated = self.db_pool.simple_select_one_onecol_txn( + already_allocated = db_pool.simple_select_one_onecol_txn( txn, table="event_auth_chains", keyvalues={ @@ -650,14 +662,14 @@ class PersistEventsStore: ) if not new_chain_tuple: - new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1) + new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1) chains_tuples_allocated.add(new_chain_tuple) chain_map[event_id] = new_chain_tuple new_chain_tuples[event_id] = new_chain_tuple - self.db_pool.simple_insert_many_txn( + db_pool.simple_insert_many_txn( txn, table="event_auth_chains", values=[ @@ -666,7 +678,7 @@ class PersistEventsStore: ], ) - self.db_pool.simple_delete_many_txn( + db_pool.simple_delete_many_txn( txn, table="event_auth_chain_to_calculate", keyvalues={}, @@ -699,7 +711,7 @@ class PersistEventsStore: # Step 1, fetch all existing links from all the chains we've seen # referenced. chain_links = _LinkMap() - rows = self.db_pool.simple_select_many_txn( + rows = db_pool.simple_select_many_txn( txn, table="event_auth_chain_links", column="origin_chain_id", @@ -730,11 +742,11 @@ class PersistEventsStore: # auth events (A, B) to check if B is reachable from A. reduction = { a_id - for a_id in event_to_auth_chain[event_id] + for a_id in event_to_auth_chain.get(event_id, []) if chain_map[a_id][0] != chain_id } for start_auth_id, end_auth_id in itertools.permutations( - event_to_auth_chain[event_id], r=2, + event_to_auth_chain.get(event_id, []), r=2, ): if chain_links.exists_path_from( chain_map[start_auth_id], chain_map[end_auth_id] @@ -763,7 +775,7 @@ class PersistEventsStore: (chain_id, sequence_number), (target_id, target_seq) ) - self.db_pool.simple_insert_many_txn( + db_pool.simple_insert_many_txn( txn, table="event_auth_chain_links", values=[