summary refs log tree commit diff
path: root/synapse/storage/databases/main/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events.py')
-rw-r--r--synapse/storage/databases/main/events.py44
1 files changed, 19 insertions, 25 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 897fa06639..a396a201d4 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -109,10 +109,8 @@ class PersistEventsStore:
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = (
-            self.store._backfill_id_gen
-        )  # type: MultiWriterIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
 
         # This should only exist on instances that are configured to write
         assert (
@@ -221,7 +219,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []  # type: List[str]
+        results: List[str] = []
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -508,7 +506,7 @@ class PersistEventsStore:
         """
 
         # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+        chain_map: Dict[str, Tuple[int, int]] = {}
 
         # Set of event IDs to calculate chain ID/seq numbers for.
         events_to_calc_chain_id_for = set(event_to_room_id)
@@ -817,8 +815,8 @@ class PersistEventsStore:
         #      new chain if the sequence number has already been allocated.
         #
 
-        existing_chains = set()  # type: Set[int]
-        tree = []  # type: List[Tuple[str, Optional[str]]]
+        existing_chains: Set[int] = set()
+        tree: List[Tuple[str, Optional[str]]] = []
 
         # We need to do this in a topologically sorted order as we want to
         # generate chain IDs/sequence numbers of an event's auth events before
@@ -848,7 +846,7 @@ class PersistEventsStore:
         )
         txn.execute(sql % (clause,), args)
 
-        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+        chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}
 
         # Allocate the new events chain ID/sequence numbers.
         #
@@ -858,8 +856,8 @@ class PersistEventsStore:
         # number of new chain IDs in one call, replacing all temporary
         # objects with real allocated chain IDs.
 
-        unallocated_chain_ids = set()  # type: Set[object]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        unallocated_chain_ids: Set[object] = set()
+        new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
         for event_id, auth_event_id in tree:
             # If we reference an auth_event_id we fetch the allocated chain ID,
             # either from the existing `chain_map` or the newly generated
@@ -870,7 +868,7 @@ class PersistEventsStore:
                 if not existing_chain_id:
                     existing_chain_id = chain_map[auth_event_id]
 
-            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            new_chain_tuple: Optional[Tuple[Any, int]] = None
             if existing_chain_id:
                 # We found a chain ID/sequence number candidate, check its
                 # not already taken.
@@ -897,9 +895,9 @@ class PersistEventsStore:
         )
 
         # Map from potentially temporary chain ID to real chain ID
-        chain_id_to_allocated_map = dict(
+        chain_id_to_allocated_map: Dict[Any, int] = dict(
             zip(unallocated_chain_ids, newly_allocated_chain_ids)
-        )  # type: Dict[Any, int]
+        )
         chain_id_to_allocated_map.update((c, c) for c in existing_chains)
 
         return {
@@ -1175,9 +1173,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = (
-            OrderedDict()
-        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
+        new_events_and_contexts: OrderedDict[
+            str, Tuple[EventBase, EventContext]
+        ] = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -1205,7 +1203,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}  # type: Dict[str, int]
+        depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1580,11 +1578,11 @@ class PersistEventsStore:
         # invalidate the cache for the redacted event
         txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
 
-        self.db_pool.simple_insert_txn(
+        self.db_pool.simple_upsert_txn(
             txn,
             table="redactions",
+            keyvalues={"event_id": event.event_id},
             values={
-                "event_id": event.event_id,
                 "redacts": event.redacts,
                 "received_ts": self._clock.time_msec(),
             },
@@ -1885,7 +1883,7 @@ class PersistEventsStore:
                 ),
             )
 
-            room_to_event_ids = {}  # type: Dict[str, List[str]]
+            room_to_event_ids: Dict[str, List[str]] = {}
             for e, _ in events_and_contexts:
                 room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
 
@@ -2012,10 +2010,6 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}  # type: Dict[str, List[EventBase]]
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
         query = (
             "INSERT INTO event_backward_extremities (event_id, room_id)"
             " SELECT ?, ? WHERE NOT EXISTS ("