summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_federation.py187
1 files changed, 123 insertions, 64 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4710224708..dcfe8caf47 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,17 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 import attr
 from prometheus_client import Counter, Gauge
@@ -33,7 +43,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
-        room = await self.get_room(room_id)
+        room = await self.get_room(room_id)  # type: ignore[attr-defined]
         if room["has_auth_chain_index"]:
             try:
                 return await self.db_pool.runInteraction(
@@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_ids_using_cover_index_txn(
-        self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        event_ids: Collection[str],
+        include_given: bool,
     ) -> Set[str]:
         """Calculates the auth chain IDs using the chain index."""
 
@@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
-        for batch in batch_iter(event_chains, 1000):
+        for batch2 in batch_iter(event_chains, 1000):
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch
+                txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
@@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         front = set(event_ids)
         while front:
-            new_front = set()
+            new_front: Set[str] = set()
             for chunk in batch_iter(front, 100):
                 # Pull the auth events either from the cache or DB.
                 to_fetch: List[str] = []  # Event IDs to fetch from DB
@@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
                     # Note we need to batch up the results by event ID before
                     # adding to the cache.
-                    to_cache = {}
+                    to_cache: Dict[str, List[Tuple[str, int]]] = {}
                     for event_id, auth_event_id, auth_event_depth in txn:
                         to_cache.setdefault(event_id, []).append(
                             (auth_event_id, auth_event_depth)
@@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
-        room = await self.get_room(room_id)
+        room = await self.get_room(room_id)  # type: ignore[attr-defined]
         if room["has_auth_chain_index"]:
             try:
                 return await self.db_pool.runInteraction(
@@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_difference_using_cover_index_txn(
-        self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+        self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
     ) -> Set[str]:
         """Calculates the auth chain difference using the chain index.
 
@@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # (We need to take a copy of `seen_chains` as we want to mutate it in
         # the loop)
-        for batch in batch_iter(set(seen_chains), 1000):
+        for batch2 in batch_iter(set(seen_chains), 1000):
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch
+                txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
@@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return result
 
     def _get_auth_chain_difference_txn(
-        self, txn, state_sets: List[Set[str]]
+        self, txn: LoggingTransaction, state_sets: List[Set[str]]
     ) -> Set[str]:
         """Calculates the auth chain difference using a breadth first search.
 
@@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             # I think building a temporary list with fetchall is more efficient than
             # just `search.extend(txn)`, but this is unconfirmed
-            search.extend(txn.fetchall())
+            search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
 
         # sort by depth
         search.sort()
@@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 # We parse the results and add the to the `found` set and the
                 # cache (note we need to batch up the results by event ID before
                 # adding to the cache).
-                to_cache = {}
+                to_cache: Dict[str, List[Tuple[str, int]]] = {}
                 for event_id, auth_event_id, auth_event_depth in txn:
                     to_cache.setdefault(event_id, []).append(
                         (auth_event_id, auth_event_depth)
@@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
     async def get_oldest_event_ids_with_depth_in_room(
-        self, room_id
+        self, room_id: str
     ) -> List[Tuple[str, int]]:
         """Gets the oldest events(backwards extremities) in the room along with the
         aproximate depth.
@@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             List of (event_id, depth) tuples
         """
 
-        def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+        def get_oldest_event_ids_with_depth_in_room_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> List[Tuple[str, int]]:
             # Assemble a dictionary with event_id -> depth for the oldest events
             # we know of in the room. Backwards extremeties are the oldest
             # events we know of in the room but we only know of them because
@@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             txn.execute(sql, (room_id, False))
 
-            return txn.fetchall()
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_oldest_event_ids_with_depth_in_room",
@@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     async def get_insertion_event_backward_extremities_in_room(
-        self, room_id
+        self, room_id: str
     ) -> List[Tuple[str, int]]:
         """Get the insertion events we know about that we haven't backfilled yet.
 
@@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             List of (event_id, depth) tuples
         """
 
-        def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
+        def get_insertion_event_backward_extremities_in_room_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> List[Tuple[str, int]]:
             sql = """
                 SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
                 /* We only want insertion events that are also marked as backwards extremities */
@@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             """
 
             txn.execute(sql, (room_id,))
-            return txn.fetchall()
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_insertion_event_backward_extremities_in_room",
@@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
@@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             return max_depth_event_id, current_max_depth
 
-    async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+    async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the min depth from a set of event IDs
 
         Args:
@@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
         )
 
-    def _get_prev_events_for_room_txn(self, txn, room_id: str):
+    def _get_prev_events_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> List[str]:
         # we just use the 10 newest events. Older events will become
         # prev_events of future events.
 
@@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             sorted by extremity count.
         """
 
-        def _get_rooms_with_many_extremities_txn(txn):
+        def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
             where_clause = "1=1"
             if room_id_filter:
                 where_clause = "room_id NOT IN (%s)" % (
@@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
-    def _get_min_depth_interaction(self, txn, room_id):
+    def _get_min_depth_interaction(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> Optional[int]:
         min_depth = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="room_depth",
@@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         """
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
-        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)  # type: ignore[attr-defined]
 
         # We don't always have a full stream_to_exterm_id table, e.g. after
         # the upgrade that introduced it, so we make sure we never ask for a
         # stream_ordering from before a restart
-        last_change = max(self._stream_order_on_start, last_change)
+        last_change = max(self._stream_order_on_start, last_change)  # type: ignore[attr-defined]
 
         # provided the last_change is recent enough, we now clamp the requested
         # stream_ordering to it.
-        if last_change > self.stream_ordering_month_ago:
+        if last_change > self.stream_ordering_month_ago:  # type: ignore[attr-defined]
             stream_ordering = min(last_change, stream_ordering)
 
         return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
 
     @cached(max_entries=5000, num_args=2)
-    async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+    async def _get_forward_extremeties_for_room(
+        self, room_id: str, stream_ordering: int
+    ) -> List[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         stream_orderings from that point.
         """
 
-        if stream_ordering <= self.stream_ordering_month_ago:
+        if stream_ordering <= self.stream_ordering_month_ago:  # type: ignore[attr-defined]
             raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
 
         sql = """
@@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 WHERE room_id = ?
         """
 
-        def get_forward_extremeties_for_room_txn(txn):
+        def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
@@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         ]
 
     async def get_backfill_events(
-        self, room_id: str, seed_event_id_list: list, limit: int
-    ):
+        self, room_id: str, seed_event_id_list: List[str], limit: int
+    ) -> List[EventBase]:
         """Get a list of Events for a given topic that occurred before (and
         including) the events in seed_event_id_list. Return a list of max size `limit`
 
@@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         events = await self.get_events_as_list(event_ids)
         return sorted(
-            events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+            # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
+            # But it's never None, because these events were previously persisted to the DB.
+            events,
+            key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering),  # type: ignore[operator]
         )
 
-    def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+    def _get_backfill_events(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        seed_event_id_list: List[str],
+        limit: int,
+    ) -> Set[str]:
         """
         We want to make sure that we do a breadth-first, "depth" ordered search.
         We also handle navigating historical branches of history connected by
@@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             limit,
         )
 
-        event_id_results = set()
+        event_id_results: Set[str] = set()
 
         # In a PriorityQueue, the lowest valued entries are retrieved first.
         # We're using depth as the priority in the queue and tie-break based on
@@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # highest and newest-in-time message. We add events to the queue with a
         # negative depth so that we process the newest-in-time messages first
         # going backwards in time. stream_ordering follows the same pattern.
-        queue = PriorityQueue()
+        queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
 
         for seed_event_id in seed_event_id_list:
             event_lookup_result = self.db_pool.simple_select_one_txn(
@@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return event_id_results
 
-    async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+    async def get_missing_events(
+        self,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[EventBase]:
         ids = await self.db_pool.runInteraction(
             "get_missing_events",
             self._get_missing_events,
@@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         return await self.get_events_as_list(ids)
 
-    def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
+    def _get_missing_events(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[str]:
 
         seen_events = set(earliest_events)
         front = set(latest_events) - seen_events
-        event_results = []
+        event_results: List[str] = []
 
         query = (
             "SELECT prev_event_id FROM event_edges "
@@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     @wrap_as_background_process("delete_old_forward_extrem_cache")
     async def _delete_old_forward_extrem_cache(self) -> None:
-        def _delete_old_forward_extrem_cache_txn(txn):
+        def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
             # Delete entries older than a month, while making sure we don't delete
             # the only entries for a room.
             sql = """
@@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 ) AND stream_ordering < ?
             """
             txn.execute(
-                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)  # type: ignore[attr-defined]
             )
 
         await self.db_pool.runInteraction(
@@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         """
         if self.db_pool.engine.supports_returning:
 
-            def _remove_received_event_from_staging_txn(txn):
+            def _remove_received_event_from_staging_txn(
+                txn: LoggingTransaction,
+            ) -> Optional[int]:
                 sql = """
                     DELETE FROM federation_inbound_events_staging
                     WHERE origin = ? AND event_id = ?
@@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 """
 
                 txn.execute(sql, (origin, event_id))
-                return txn.fetchone()
+                row = cast(Optional[Tuple[int]], txn.fetchone())
 
-            row = await self.db_pool.runInteraction(
+                if row is None:
+                    return None
+
+                return row[0]
+
+            return await self.db_pool.runInteraction(
                 "remove_received_event_from_staging",
                 _remove_received_event_from_staging_txn,
                 db_autocommit=True,
             )
-            if row is None:
-                return None
-
-            return row[0]
 
         else:
 
-            def _remove_received_event_from_staging_txn(txn):
+            def _remove_received_event_from_staging_txn(
+                txn: LoggingTransaction,
+            ) -> Optional[int]:
                 received_ts = self.db_pool.simple_select_one_onecol_txn(
                     txn,
                     table="federation_inbound_events_staging",
@@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     ) -> Optional[Tuple[str, str]]:
         """Get the next event ID in the staging area for the given room."""
 
-        def _get_next_staged_event_id_for_room_txn(txn):
+        def _get_next_staged_event_id_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, str]]:
             sql = """
                 SELECT origin, event_id
                 FROM federation_inbound_events_staging
@@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             txn.execute(sql, (room_id,))
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, str]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     ) -> Optional[Tuple[str, EventBase]]:
         """Get the next event in the staging area for the given room."""
 
-        def _get_next_staged_event_for_room_txn(txn):
+        def _get_next_staged_event_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, str, str]]:
             sql = """
                 SELECT event_json, internal_metadata, origin
                 FROM federation_inbound_events_staging
@@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             """
             txn.execute(sql, (room_id,))
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, str, str]], txn.fetchone())
 
         row = await self.db_pool.runInteraction(
             "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @wrap_as_background_process("_get_stats_for_federation_staging")
-    async def _get_stats_for_federation_staging(self):
+    async def _get_stats_for_federation_staging(self) -> None:
         """Update the prometheus metrics for the inbound federation staging area."""
 
-        def _get_stats_for_federation_staging_txn(txn):
+        def _get_stats_for_federation_staging_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, int]:
             txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
 
             txn.execute(
                 "SELECT min(received_ts) FROM federation_inbound_events_staging"
             )
 
-            (received_ts,) = txn.fetchone()
+            (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
 
             # If there is nothing in the staging area default it to 0.
             age = 0
@@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
-    async def clean_room_for_join(self, room_id):
-        return await self.db_pool.runInteraction(
+    async def clean_room_for_join(self, room_id: str) -> None:
+        await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
 
-    def _clean_room_for_join_txn(self, txn, room_id):
+    def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
         txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
-    async def _background_delete_non_state_event_auth(self, progress, batch_size):
-        def delete_event_auth(txn):
+    async def _background_delete_non_state_event_auth(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def delete_event_auth(txn: LoggingTransaction) -> bool:
             target_min_stream_id = progress.get("target_min_stream_id_inclusive")
             max_stream_id = progress.get("max_stream_id_exclusive")