summary refs log tree commit diff
path: root/synapse/storage/databases/main/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events.py')
-rw-r--r--synapse/storage/databases/main/events.py50
1 files changed, 36 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 186f064036..e0fbcc58cf 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -466,9 +466,6 @@ class PersistEventsStore:
         if not state_events:
             return
 
-        # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
-
         # We need to know the type/state_key and auth events of the events we're
         # calculating chain IDs for. We don't rely on having the full Event
         # instances as we'll potentially be pulling more events from the DB and
@@ -479,9 +476,33 @@ class PersistEventsStore:
         event_to_auth_chain = {
             e.event_id: e.auth_event_ids() for e in state_events.values()
         }
+        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+        self._add_chain_cover_index(
+            txn, event_to_room_id, event_to_types, event_to_auth_chain
+        )
+
+    def _add_chain_cover_index(
+        self,
+        txn,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+    ) -> None:
+        """Calculate the chain cover index for the given events.
+
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
 
         # Set of event IDs to calculate chain ID/seq numbers for.
-        events_to_calc_chain_id_for = set(state_events)
+        events_to_calc_chain_id_for = set(event_to_room_id)
 
         # We check if there are any events that need to be handled in the rooms
         # we're looking at. These should just be out of band memberships, where
@@ -491,7 +512,7 @@ class PersistEventsStore:
             table="event_auth_chain_to_calculate",
             keyvalues={},
             column="room_id",
-            iterable={e.room_id for e in state_events.values()},
+            iterable=set(event_to_room_id.values()),
             retcols=("event_id", "type", "state_key"),
         )
         for row in rows:
@@ -582,16 +603,17 @@ class PersistEventsStore:
                     # the list of events to calculate chain IDs for next time
                     # around. (Otherwise we will have already added it to the
                     # table).
-                    event = state_events.get(event_id)
-                    if event:
+                    room_id = event_to_room_id.get(event_id)
+                    if room_id:
+                        e_type, state_key = event_to_types[event_id]
                         self.db_pool.simple_insert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
                             values={
-                                "event_id": event.event_id,
-                                "room_id": event.room_id,
-                                "type": event.type,
-                                "state_key": event.state_key,
+                                "event_id": event_id,
+                                "room_id": room_id,
+                                "type": e_type,
+                                "state_key": state_key,
                             },
                         )
 
@@ -617,7 +639,7 @@ class PersistEventsStore:
             events_to_calc_chain_id_for, event_to_auth_chain
         ):
             existing_chain_id = None
-            for auth_id in event_to_auth_chain[event_id]:
+            for auth_id in event_to_auth_chain.get(event_id, []):
                 if event_to_types.get(event_id) == event_to_types.get(auth_id):
                     existing_chain_id = chain_map[auth_id]
                     break
@@ -730,11 +752,11 @@ class PersistEventsStore:
             # auth events (A, B) to check if B is reachable from A.
             reduction = {
                 a_id
-                for a_id in event_to_auth_chain[event_id]
+                for a_id in event_to_auth_chain.get(event_id, [])
                 if chain_map[a_id][0] != chain_id
             }
             for start_auth_id, end_auth_id in itertools.permutations(
-                event_to_auth_chain[event_id], r=2,
+                event_to_auth_chain.get(event_id, []), r=2,
             ):
                 if chain_links.exists_path_from(
                     chain_map[start_auth_id], chain_map[end_auth_id]