summary refs log tree commit diff
path: root/synapse/storage/databases/main/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events.py')
-rw-r--r--synapse/storage/databases/main/events.py156
1 files changed, 99 insertions, 57 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ad611b2c0b..6c12653bb3 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -49,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 
@@ -235,7 +235,9 @@ class PersistEventsStore:
         """
         results: List[str] = []
 
-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
             SELECT prev_event_id, internal_metadata
             FROM event_edges
@@ -285,7 +287,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()
 
-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
 
             while to_recursively_check:
@@ -515,7 +519,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -809,7 +813,7 @@ class PersistEventsStore:
 
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -943,7 +947,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
@@ -997,7 +1001,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1155,7 +1159,7 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
 
@@ -1189,7 +1193,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1254,9 +1258,9 @@ class PersistEventsStore:
 
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room
 
         Args:
@@ -1385,7 +1389,7 @@ class PersistEventsStore:
             # nothing to do here
             return
 
-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1476,18 +1480,20 @@ class PersistEventsStore:
             ),
         )
 
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
@@ -1508,7 +1514,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1602,7 +1608,11 @@ class PersistEventsStore:
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
 
         rows = []
@@ -1633,7 +1643,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1663,19 +1673,24 @@ class PersistEventsStore:
         )
 
     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.
 
         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1684,25 +1699,32 @@ class PersistEventsStore:
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.
 
         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.
+
         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1742,6 +1764,7 @@ class PersistEventsStore:
         )
 
         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
@@ -1838,7 +1861,9 @@ class PersistEventsStore:
                 (parent_id, event.sender),
             )
 
-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
 
@@ -1899,7 +1924,7 @@ class PersistEventsStore:
                 },
             )
 
-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
 
@@ -1999,25 +2024,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
 
-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2077,8 +2106,11 @@ class PersistEventsStore:
         )
 
     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -2086,12 +2118,10 @@ class PersistEventsStore:
         from the push action staging area.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """
 
         # Only non outlier events will have push actions associated with them,
@@ -2160,7 +2190,9 @@ class PersistEventsStore:
             ),
         )
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2171,7 +2203,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )
 
-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2183,8 +2217,10 @@ class PersistEventsStore:
         )
 
     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2241,7 +2277,9 @@ class PersistEventsStore:
                 state_group_id,
             )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
         if min_depth is not None and depth >= min_depth:
@@ -2254,7 +2292,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
 
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -2272,7 +2312,9 @@ class PersistEventsStore:
 
         self._update_backward_extremeties(txn, events)
 
-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.