summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12255.misc1
-rw-r--r--mypy.ini3
-rw-r--r--synapse/storage/databases/main/media_repository.py2
-rw-r--r--synapse/storage/databases/main/receipts.py37
-rw-r--r--synapse/storage/databases/main/registration.py3
-rw-r--r--synapse/storage/databases/main/room.py3
-rw-r--r--synapse/storage/databases/main/search.py26
-rw-r--r--synapse/storage/databases/main/state.py24
-rw-r--r--synapse/storage/databases/main/stats.py2
-rw-r--r--synapse/storage/databases/main/user_directory.py2
10 files changed, 61 insertions, 42 deletions
diff --git a/changelog.d/12255.misc b/changelog.d/12255.misc
new file mode 100644
index 0000000000..2b1290d1e1
--- /dev/null
+++ b/changelog.d/12255.misc
@@ -0,0 +1 @@
+Add missing type hints for storage.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index 24d4ba15d4..4781ce56f2 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -38,14 +38,11 @@ exclude = (?x)
    |synapse/_scripts/update_synapse_database.py
 
    |synapse/storage/databases/__init__.py
-   |synapse/storage/databases/main/__init__.py
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
    |synapse/storage/databases/main/event_federation.py
    |synapse/storage/databases/main/push_rule.py
-   |synapse/storage/databases/main/receipts.py
    |synapse/storage/databases/main/roommember.py
-   |synapse/storage/databases/main/search.py
    |synapse/storage/databases/main/state.py
    |synapse/storage/schema/
 
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index cbba356b4a..322ed05390 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
     async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index bf0b903af2..e6f97aeece 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -24,10 +24,9 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
-from twisted.internet import defer
-
 from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
@@ -38,7 +37,11 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         hs: "HomeServer",
     ):
         self._instance_name = hs.get_instance_name()
+        self._receipts_id_gen: AbstractStreamIdTracker
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_receipts = (
@@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 " AND user_id = ?"
             )
             txn.execute(sql, (user_id,))
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             "get_receipts_for_user_with_orderings", f
@@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if not rows:
             return []
 
-        content = {}
+        content: JsonDict = {}
         for row in rows:
             content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                 row["user_id"]
@@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "_get_linearized_receipts_for_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "get_linearized_receipts_for_all_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """
 
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
             sql = """
@@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             """
             txn.execute(sql, (last_id, current_id, limit))
 
-            updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+            )
 
             limited = False
             upper_bound = current_id
@@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == ReceiptsStream.NAME:
             self._receipts_id_gen.advance(instance_name, token)
             for row in rows:
@@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
         if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
-            self._remove_old_push_actions_before_txn(
+            self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
 
@@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        async with self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a698d10cc5..7f3d190e94 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -22,6 +22,7 @@ import attr
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
     DatabasePool,
@@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self.config: HomeServerConfig = hs.config
 
         # Note: we don't check this sequence for consistency as we'd have to
         # call `find_max_generated_user_id_localpart` each time, which is
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 94068940b9..18b1acd9e1 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -34,6 +34,7 @@ import attr
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self.config: HomeServerConfig = hs.config
 
     async def store_room(
         self,
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index c5e9010c83..d4482c06db 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
 
 import attr
 
@@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
                 " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
             )
 
-            args = (
+            args1 = (
                 (
                     entry.event_id,
                     entry.room_id,
@@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args1)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
                 "INSERT INTO event_search (event_id, room_id, key, value)"
                 " VALUES (?,?,?,?)"
             )
-            args = (
+            args2 = (
                 (
                     entry.event_id,
                     entry.room_id,
@@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore):
                 )
                 for entry in entries
             )
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args2)
 
         else:
             # This should be unreachable.
@@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         search_query = _parse_query(self.database_engine, search_term)
 
-        args = []
+        args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )
@@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         room_ids: Collection[str],
         search_term: str,
         keys: Iterable[str],
-        limit,
+        limit: int,
         pagination_token: Optional[str] = None,
     ) -> JsonDict:
         """Performs a full text search over events with given keys.
@@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         search_query = _parse_query(self.database_engine, search_term)
 
-        args = []
+        args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         if pagination_token:
             try:
-                origin_server_ts, stream = pagination_token.split(",")
-                origin_server_ts = int(origin_server_ts)
-                stream = int(stream)
+                origin_server_ts_str, stream_str = pagination_token.split(",")
+                origin_server_ts = int(origin_server_ts_str)
+                stream = int(stream_str)
             except Exception:
                 raise SynapseError(400, "Invalid pagination token")
 
@@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 417aef1dbc..28460fd364 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import collections.abc
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +29,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             # We delegate to the cached version
             return await self.get_current_state_ids(room_id)
 
-        def _get_filtered_current_state_ids_txn(txn):
+        def _get_filtered_current_state_ids_txn(
+            txn: LoggingTransaction,
+        ) -> StateMap[str]:
             results = {}
             sql = """
                 SELECT type, state_key, event_id FROM current_state_events
@@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         event_id = state.get((EventTypes.CanonicalAlias, ""))
         if not event_id:
-            return
+            return None
 
         event = await self.get_event(event_id, allow_none=True)
         if not event:
-            return
+            return None
 
         return event.content.get("canonical_alias")
 
@@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         list_name="event_ids",
         num_args=1,
     )
-    async def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
         """Returns mapping event_id -> state_group"""
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
@@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
         self.db_pool.updates.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             self._background_remove_left_rooms,
         )
 
-    async def _background_remove_left_rooms(self, progress, batch_size):
+    async def _background_remove_left_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to delete rows from `current_state_events` and
         `event_forward_extremities` tables of rooms that the server is no
         longer joined to.
@@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
 
         last_room_id = progress.get("last_room_id", "")
 
-        def _background_remove_left_rooms_txn(txn):
+        def _background_remove_left_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[bool, Set[str]]:
             # get a batch of room ids to consider
             sql = """
                 SELECT DISTINCT room_id FROM current_state_events
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 427ae1f649..b95dbef678 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
         self.clock = self.hs.get_clock()
         self.stats_enabled = hs.config.stats.stats_enabled
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 0595df01d3..df772d4721 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -68,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     ) -> None:
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
         self.db_pool.updates.register_background_update_handler(
             "populate_user_directory_createtables",