summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorJonathan de Jong <jonathan@automatia.nl>2021-07-15 18:46:54 +0200
committerGitHub <noreply@github.com>2021-07-15 12:46:54 -0400
commitbdfde6dca11a9468372b3c9b327ad3327cbdbe4a (patch)
treee3185688882f25f08cc0aefa80d8e1944c5004d9 /synapse/storage/databases
parentReduce likelihood of Postgres table scanning `state_groups_state`. (#10359) (diff)
downloadsynapse-bdfde6dca11a9468372b3c9b327ad3327cbdbe4a.tar.xz
Use inline type hints in `http/federation/`, `storage/` and `util/` (#10381)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/appservice.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py2
-rw-r--r--synapse/storage/databases/main/event_federation.py26
-rw-r--r--synapse/storage/databases/main/event_push_actions.py2
-rw-r--r--synapse/storage/databases/main/events.py38
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py8
-rw-r--r--synapse/storage/databases/main/events_worker.py6
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py6
-rw-r--r--synapse/storage/databases/main/registration.py2
-rw-r--r--synapse/storage/databases/main/stream.py6
-rw-r--r--synapse/storage/databases/main/tags.py2
-rw-r--r--synapse/storage/databases/main/ui_auth.py4
13 files changed, 51 insertions, 57 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 9f182c2a89..e2d1b758bd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -48,9 +48,7 @@ def _make_exclusive_regex(
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_pattern = re.compile(
-            exclusive_user_regex
-        )  # type: Optional[Pattern]
+        exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 0e3dd4e9ca..78ae68ec68 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -247,7 +247,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         txn.execute(sql, query_params)
 
-        result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+        result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
         for (user_id, device_id, display_name, key_json) in txn:
             if include_deleted_devices:
                 deleted_devices.remove((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4e06938849..d39368c20e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             )
 
         # Cache of event ID to list of auth event IDs and their depths.
-        self._event_auth_cache = LruCache(
+        self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
             500000, "_event_auth_cache", size_callback=len
-        )  # type: LruCache[str, List[Tuple[str, int]]]
+        )
 
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
@@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(event_ids)
 
         # All the events that we've found that are reachable from the events.
-        seen_events = set()  # type: Set[str]
+        seen_events: Set[str] = set()
 
         # A map from chain ID to max sequence number of the given events.
-        event_chains = {}  # type: Dict[int, int]
+        event_chains: Dict[int, int] = {}
 
         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
-        chains = {}  # type: Dict[int, int]
+        chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
         for batch in batch_iter(event_chains, 1000):
@@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(state_sets[0]).union(*state_sets[1:])
 
         # Map from event_id -> (chain ID, seq no)
-        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+        chain_info: Dict[str, Tuple[int, int]] = {}
 
         # Map from chain ID -> seq no -> event Id
-        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+        chain_to_event: Dict[int, Dict[int, str]] = {}
 
         # All the chains that we've found that are reachable from the state
         # sets.
-        seen_chains = set()  # type: Set[int]
+        seen_chains: Set[int] = set()
 
         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Corresponds to `state_sets`, except as a map from chain ID to max
         # sequence number reachable from the state set.
-        set_to_chain = []  # type: List[Dict[int, int]]
+        set_to_chain: List[Dict[int, int]] = []
         for state_set in state_sets:
-            chains = {}  # type: Dict[int, int]
+            chains: Dict[int, int] = {}
             set_to_chain.append(chains)
 
             for event_id in state_set:
@@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Mapping from chain ID to the range of sequence numbers that should be
         # pulled from the database.
-        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+        chain_to_gap: Dict[int, Tuple[int, int]] = {}
 
         for chain_id in seen_chains:
             min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
@@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         }
 
         # The sorted list of events whose auth chains we should walk.
-        search = []  # type: List[Tuple[int, str]]
+        search: List[Tuple[int, str]] = []
 
         # We need to get the depth of the initial events for sorting purposes.
         sql = """
@@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         search.sort()
 
         # Map from event to its auth events
-        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+        event_to_auth_events: Dict[str, Set[str]] = {}
 
         base_sql = """
             SELECT a.event_id, auth_id, depth
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d1237c65cc..55caa6bbe7 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
         for row in txn:
             summaries[(row[0], row[1])] = _EventPushSummary(
                 unread_count=row[2],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 08c580b0dc..ec8579b9ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -109,10 +109,8 @@ class PersistEventsStore:
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = (
-            self.store._backfill_id_gen
-        )  # type: MultiWriterIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
 
         # This should only exist on instances that are configured to write
         assert (
@@ -221,7 +219,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []  # type: List[str]
+        results: List[str] = []
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -508,7 +506,7 @@ class PersistEventsStore:
         """
 
         # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+        chain_map: Dict[str, Tuple[int, int]] = {}
 
         # Set of event IDs to calculate chain ID/seq numbers for.
         events_to_calc_chain_id_for = set(event_to_room_id)
@@ -817,8 +815,8 @@ class PersistEventsStore:
         #      new chain if the sequence number has already been allocated.
         #
 
-        existing_chains = set()  # type: Set[int]
-        tree = []  # type: List[Tuple[str, Optional[str]]]
+        existing_chains: Set[int] = set()
+        tree: List[Tuple[str, Optional[str]]] = []
 
         # We need to do this in a topologically sorted order as we want to
         # generate chain IDs/sequence numbers of an event's auth events before
@@ -848,7 +846,7 @@ class PersistEventsStore:
         )
         txn.execute(sql % (clause,), args)
 
-        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+        chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}
 
         # Allocate the new events chain ID/sequence numbers.
         #
@@ -858,8 +856,8 @@ class PersistEventsStore:
         # number of new chain IDs in one call, replacing all temporary
         # objects with real allocated chain IDs.
 
-        unallocated_chain_ids = set()  # type: Set[object]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        unallocated_chain_ids: Set[object] = set()
+        new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
         for event_id, auth_event_id in tree:
             # If we reference an auth_event_id we fetch the allocated chain ID,
             # either from the existing `chain_map` or the newly generated
@@ -870,7 +868,7 @@ class PersistEventsStore:
                 if not existing_chain_id:
                     existing_chain_id = chain_map[auth_event_id]
 
-            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            new_chain_tuple: Optional[Tuple[Any, int]] = None
             if existing_chain_id:
                 # We found a chain ID/sequence number candidate, check its
                 # not already taken.
@@ -897,9 +895,9 @@ class PersistEventsStore:
         )
 
         # Map from potentially temporary chain ID to real chain ID
-        chain_id_to_allocated_map = dict(
+        chain_id_to_allocated_map: Dict[Any, int] = dict(
             zip(unallocated_chain_ids, newly_allocated_chain_ids)
-        )  # type: Dict[Any, int]
+        )
         chain_id_to_allocated_map.update((c, c) for c in existing_chains)
 
         return {
@@ -1175,9 +1173,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = (
-            OrderedDict()
-        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
+        new_events_and_contexts: OrderedDict[
+            str, Tuple[EventBase, EventContext]
+        ] = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -1205,7 +1203,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}  # type: Dict[str, int]
+        depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1885,7 +1883,7 @@ class PersistEventsStore:
                 ),
             )
 
-            room_to_event_ids = {}  # type: Dict[str, List[str]]
+            room_to_event_ids: Dict[str, List[str]] = {}
             for e, _ in events_and_contexts:
                 room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
 
@@ -2012,7 +2010,7 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}  # type: Dict[str, List[EventBase]]
+        events_by_room: Dict[str, List[EventBase]] = {}
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 29f33bac55..6fcb2b8353 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         event_to_types = {row[0]: (row[1], row[2]) for row in rows}
 
         # Calculate the new last position we've processed up to.
-        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+        new_last_depth: int = rows[-1][3] if rows else last_depth
+        new_last_stream: int = rows[-1][4] if rows else last_stream
+        new_last_room_id: str = rows[-1][5] if rows else ""
 
         # Map from room_id to last depth/stream_ordering processed for the room,
         # excluding the last room (which we're likely still processing). We also
@@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             retcols=("event_id", "auth_id"),
         )
 
-        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        event_to_auth_chain: Dict[str, List[str]] = {}
         for row in auth_events:
             event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 403a5ddaba..3c86adab56 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore):
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows = await self.db_pool.runInteraction(
+        rows: List[Tuple] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
-        )  # type: List[Tuple]
+        )
 
         # if we've got fewer rows than the limit, we're good
         if len(rows) < target_row_count:
@@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
 
         mapping = {}
-        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+        txn_id_to_event: Dict[Tuple[str, int, str], str] = {}
 
         for event in events:
             token_id = getattr(event.internal_metadata, "token_id", None)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index eb4841830d..664c65dac5 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         logger.info("[purge] looking for events to delete")
 
         should_delete_expr = "state_key IS NULL"
-        should_delete_params = ()  # type: Tuple[Any, ...]
+        should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index db52176337..a7fb8cd848 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -79,9 +79,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = StreamIdGenerator(
-                db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen: Union[
+                StreamIdGenerator, SlavedIdTracker
+            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e31c5864ac..6ad1a0cf7f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
-            values = [v for _, v in items]  # type: List[Union[str, int]]
+            values: List[Union[str, int]] = [v for _, v in items]
             # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
             # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
             # clause and values before we handle that. This seems to be only used in the "set password" handler.
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7581c7d3ff..959f13de47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
         # then filtering the results.
         if from_token.topological is not None:
-            from_bound = (
-                from_token.as_historical_tuple()
-            )  # type: Tuple[Optional[int], int]
+            from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
         elif direction == "b":
             from_bound = (
                 None,
@@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 from_token.stream,
             )
 
-        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        to_bound: Optional[Tuple[Optional[int], int]] = None
         if to_token:
             if to_token.topological is not None:
                 to_bound = to_token.as_historical_tuple()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 1d62c6140f..f93ff0a545 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
-        tags_by_room = {}  # type: Dict[str, Dict[str, JsonDict]]
+        tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
         for row in rows:
             room_tags = tags_by_room.setdefault(row["room_id"], {})
             room_tags[row["tag"]] = db_to_json(row["content"])
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 22c05cdde7..38bfdf5dad 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ):
         # Get the current value.
-        result = self.db_pool.simple_select_one_txn(
+        result: Dict[str, Any] = self.db_pool.simple_select_one_txn(  # type: ignore
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
             retcols=("serverdict",),
-        )  # type: Dict[str, Any]  # type: ignore
+        )
 
         # Update it and add it back to the database.
         serverdict = db_to_json(result["serverdict"])