summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11575.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/handlers/room_member.py6
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/room.py173
5 files changed, 108 insertions, 77 deletions
diff --git a/changelog.d/11575.misc b/changelog.d/11575.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11575.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index e38ad635aa..cbe1e8302c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -37,7 +37,6 @@ exclude = (?x)
    |synapse/storage/databases/main/purge_events.py
    |synapse/storage/databases/main/push_rule.py
    |synapse/storage/databases/main/receipts.py
-   |synapse/storage/databases/main/room.py
    |synapse/storage/databases/main/roommember.py
    |synapse/storage/databases/main/search.py
    |synapse/storage/databases/main/state.py
@@ -205,6 +204,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.events_worker]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.room]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 447e3ce571..6aa910dd10 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1020,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
         old_room = await self.store.get_room(old_room_id)
-        if old_room and old_room["is_public"]:
+        if old_room is not None and old_room["is_public"]:
             await self.store.set_room_is_public(old_room_id, False)
             await self.store.set_room_is_public(room_id, True)
 
@@ -1031,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
         for group_id in local_group_ids:
             # Add new the new room to those groups
-            await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+            await self.store.add_room_to_group(
+                group_id, room_id, old_room is not None and old_room["is_public"]
+            )
 
             # Remove the old room from those groups
             await self.store.remove_room_from_group(group_id, old_room_id)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 716b25dd34..a594223fc6 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -149,7 +149,6 @@ class DataStore(
             ],
         )
 
-        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._group_updates_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6cf6cc8484..4472335af9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -17,7 +17,7 @@ import collections
 import logging
 from abc import abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
 
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
@@ -29,8 +29,9 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -75,7 +76,7 @@ class RoomSortOrder(Enum):
     STATE_EVENTS = "state_events"
 
 
-class RoomWorkerStore(SQLBaseStore):
+class RoomWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -92,7 +93,7 @@ class RoomWorkerStore(SQLBaseStore):
         room_creator_user_id: str,
         is_public: bool,
         room_version: RoomVersion,
-    ):
+    ) -> None:
         """Stores a room.
 
         Args:
@@ -120,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
 
-    async def get_room(self, room_id: str) -> dict:
+    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve a room.
 
         Args:
@@ -145,7 +146,9 @@ class RoomWorkerStore(SQLBaseStore):
             A dict containing the room information, or None if the room is unknown.
         """
 
-        def get_room_with_stats_txn(txn, room_id):
+        def get_room_with_stats_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -194,7 +197,7 @@ class RoomWorkerStore(SQLBaseStore):
             ignore_non_federatable: If true filters out non-federatable rooms
         """
 
-        def _count_public_rooms_txn(txn):
+        def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
             if network_tuple:
@@ -235,7 +238,7 @@ class RoomWorkerStore(SQLBaseStore):
             }
 
             txn.execute(sql, query_args)
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
@@ -244,11 +247,11 @@ class RoomWorkerStore(SQLBaseStore):
     async def get_room_count(self) -> int:
         """Retrieve the total number of rooms."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = "SELECT count(*)  FROM rooms"
             txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
+            row = cast(Tuple[int], txn.fetchone())
+            return row[0]
 
         return await self.db_pool.runInteraction("get_rooms", f)
 
@@ -260,7 +263,7 @@ class RoomWorkerStore(SQLBaseStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ):
+    ) -> List[Dict[str, Any]]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
 
@@ -381,7 +384,9 @@ class RoomWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
 
-        def _get_largest_public_rooms_txn(txn):
+        def _get_largest_public_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Any]]:
             txn.execute(sql, query_args)
 
             results = self.db_pool.cursor_to_dict(txn)
@@ -444,7 +449,7 @@ class RoomWorkerStore(SQLBaseStore):
         """
         # Filter room names by a string
         where_statement = ""
-        search_pattern = []
+        search_pattern: List[object] = []
         if search_term:
             where_statement = """
                 WHERE LOWER(state.name) LIKE ?
@@ -552,7 +557,9 @@ class RoomWorkerStore(SQLBaseStore):
             where_statement,
         )
 
-        def _get_rooms_paginate_txn(txn):
+        def _get_rooms_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
             # Add the search term into the WHERE clause
             # and execute the data query
             txn.execute(info_sql, search_pattern + [limit, start])
@@ -584,7 +591,7 @@ class RoomWorkerStore(SQLBaseStore):
             # Add the search term into the WHERE clause if present
             txn.execute(count_sql, search_pattern)
 
-            room_count = txn.fetchone()
+            room_count = cast(Tuple[int], txn.fetchone())
             return rooms, room_count[0]
 
         return await self.db_pool.runInteraction(
@@ -629,7 +636,7 @@ class RoomWorkerStore(SQLBaseStore):
             burst_count: How many actions that can be performed before being limited.
         """
 
-        def set_ratelimit_txn(txn):
+        def set_ratelimit_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="ratelimit_override",
@@ -652,7 +659,7 @@ class RoomWorkerStore(SQLBaseStore):
             user_id: user ID of the user
         """
 
-        def delete_ratelimit_txn(txn):
+        def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="ratelimit_override",
@@ -676,7 +683,7 @@ class RoomWorkerStore(SQLBaseStore):
         await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
 
     @cached()
-    async def get_retention_policy_for_room(self, room_id):
+    async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
         """Get the retention policy for a given room.
 
         If no retention policy has been found for this room, returns a policy defined
@@ -685,13 +692,15 @@ class RoomWorkerStore(SQLBaseStore):
         configuration).
 
         Args:
-            room_id (str): The ID of the room to get the retention policy of.
+            room_id: The ID of the room to get the retention policy of.
 
         Returns:
-            dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+            A dict containing "min_lifetime" and "max_lifetime" for this room.
         """
 
-        def get_retention_policy_for_room_txn(txn):
+        def get_retention_policy_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Optional[int]]]:
             txn.execute(
                 """
                 SELECT min_lifetime, max_lifetime FROM room_retention
@@ -716,19 +725,23 @@ class RoomWorkerStore(SQLBaseStore):
                 "max_lifetime": self.config.retention.retention_default_max_lifetime,
             }
 
-        row = ret[0]
+        min_lifetime = ret[0]["min_lifetime"]
+        max_lifetime = ret[0]["max_lifetime"]
 
         # If one of the room's policy's attributes isn't defined, use the matching
         # attribute from the default policy.
         # The default values will be None if no default policy has been defined, or if one
         # of the attributes is missing from the default policy.
-        if row["min_lifetime"] is None:
-            row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
+        if min_lifetime is None:
+            min_lifetime = self.config.retention.retention_default_min_lifetime
 
-        if row["max_lifetime"] is None:
-            row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
+        if max_lifetime is None:
+            max_lifetime = self.config.retention.retention_default_max_lifetime
 
-        return row
+        return {
+            "min_lifetime": min_lifetime,
+            "max_lifetime": max_lifetime,
+        }
 
     async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
@@ -740,7 +753,9 @@ class RoomWorkerStore(SQLBaseStore):
             The local and remote media as a lists of the media IDs.
         """
 
-        def _get_media_mxcs_in_room_txn(txn):
+        def _get_media_mxcs_in_room_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], List[str]]:
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             local_media_mxcs = []
             remote_media_mxcs = []
@@ -766,7 +781,7 @@ class RoomWorkerStore(SQLBaseStore):
 
         logger.info("Quarantining media in room: %s", room_id)
 
-        def _quarantine_media_in_room_txn(txn):
+        def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             return self._quarantine_media_txn(
                 txn, local_mxcs, remote_mxcs, quarantined_by
@@ -776,13 +791,11 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
-    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+    def _get_media_mxcs_in_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> Tuple[List[str], List[Tuple[str, str]]]:
         """Retrieves all the local and remote media MXC URIs in a given room
 
-        Args:
-            txn (cursor)
-            room_id (str)
-
         Returns:
             The local and remote media as a lists of tuples where the key is
             the hostname and the value is the media ID.
@@ -850,7 +863,7 @@ class RoomWorkerStore(SQLBaseStore):
         logger.info("Quarantining media: %s/%s", server_name, media_id)
         is_local = server_name == self.config.server.server_name
 
-        def _quarantine_media_by_id_txn(txn):
+        def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
             local_mxcs = [media_id] if is_local else []
             remote_mxcs = [(server_name, media_id)] if not is_local else []
 
@@ -872,7 +885,7 @@ class RoomWorkerStore(SQLBaseStore):
             quarantined_by: The ID of the user who made the quarantine request
         """
 
-        def _quarantine_media_by_user_txn(txn):
+        def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
 
@@ -880,7 +893,9 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )
 
-    def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+    def _get_media_ids_by_user_txn(
+        self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
+    ) -> List[str]:
         """Retrieves local media IDs by a given user
 
         Args:
@@ -909,7 +924,7 @@ class RoomWorkerStore(SQLBaseStore):
 
     def _quarantine_media_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         local_mxcs: List[str],
         remote_mxcs: List[Tuple[str, str]],
         quarantined_by: Optional[str],
@@ -937,12 +952,15 @@ class RoomWorkerStore(SQLBaseStore):
         # set quarantine
         if quarantined_by is not None:
             sql += "AND safe_from_quarantine = ?"
-            rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+            txn.executemany(
+                sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+            )
         # remove from quarantine
         else:
-            rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+            txn.executemany(
+                sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+            )
 
-        txn.executemany(sql, rows)
         # Note that a rowcount of -1 can be used to indicate no rows were affected.
         total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
 
@@ -960,7 +978,7 @@ class RoomWorkerStore(SQLBaseStore):
 
     async def get_rooms_for_retention_period_in_range(
         self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
+    ) -> Dict[str, Dict[str, Optional[int]]]:
         """Retrieves all of the rooms within the given retention range.
 
         Optionally includes the rooms which don't have a retention policy.
@@ -980,7 +998,9 @@ class RoomWorkerStore(SQLBaseStore):
             "min_lifetime" (int|None), and "max_lifetime" (int|None).
         """
 
-        def get_rooms_for_retention_period_in_range_txn(txn):
+        def get_rooms_for_retention_period_in_range_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, Optional[int]]]:
             range_conditions = []
             args = []
 
@@ -1067,8 +1087,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
-
         self.db_pool.updates.register_background_update_handler(
             "insert_room_retention",
             self._background_insert_retention,
@@ -1099,7 +1117,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_populate_rooms_creator_column,
         )
 
-    async def _background_insert_retention(self, progress, batch_size):
+    async def _background_insert_retention(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Retrieves a list of all rooms within a range and inserts an entry for each of
         them into the room_retention table.
         NULLs the property's columns if missing from the retention event in the room's
@@ -1109,7 +1129,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         last_room = progress.get("room_id", "")
 
-        def _background_insert_retention_txn(txn):
+        def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
             txn.execute(
                 """
                 SELECT state.room_id, state.event_id, events.json
@@ -1168,15 +1188,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
     async def _background_add_rooms_room_version_column(
-        self, progress: dict, batch_size: int
-    ):
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to go and add room version information to `rooms`
         table from `current_state_events` table.
         """
 
         last_room_id = progress.get("room_id", "")
 
-        def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+        def _background_add_rooms_room_version_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT room_id, json FROM current_state_events
                 INNER JOIN event_json USING (room_id, event_id)
@@ -1237,7 +1259,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
     async def _remove_tombstoned_rooms_from_directory(
-        self, progress, batch_size
+        self, progress: JsonDict, batch_size: int
     ) -> int:
         """Removes any rooms with tombstone events from the room directory
 
@@ -1247,7 +1269,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         last_room = progress.get("room_id", "")
 
-        def _get_rooms(txn):
+        def _get_rooms(txn: LoggingTransaction) -> List[str]:
             txn.execute(
                 """
                 SELECT room_id
@@ -1285,7 +1307,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return len(rooms)
 
     @abstractmethod
-    def set_room_is_public(self, room_id, is_public):
+    def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
         # this will need to be implemented if a background update is performed with
         # existing (tombstoned, public) rooms in the database.
         #
@@ -1332,7 +1354,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         32-bit integer field.
         """
 
-        def process(txn: Cursor) -> int:
+        def process(txn: LoggingTransaction) -> int:
             last_room = progress.get("last_room", "")
             txn.execute(
                 """
@@ -1389,15 +1411,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return 0
 
     async def _background_populate_rooms_creator_column(
-        self, progress: dict, batch_size: int
-    ):
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to go and add creator information to `rooms`
         table from `current_state_events` table.
         """
 
         last_room_id = progress.get("room_id", "")
 
-        def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+        def _background_populate_rooms_creator_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT room_id, json FROM event_json
                 INNER JOIN rooms AS room USING (room_id)
@@ -1448,7 +1472,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
 
-class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -1457,11 +1481,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
 
     async def upsert_room_on_join(
         self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
-    ):
+    ) -> None:
         """Ensure that the room is stored in the table
 
         Called when we join a room over federation, and overwrites any room version
@@ -1507,7 +1531,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
-    ):
+    ) -> None:
         """
         When we receive an invite or any other event over federation that may relate to a room
         we are not in, store the version of the room if we don't already know the room version.
@@ -1547,8 +1571,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         self.hs.get_notifier().on_new_replication_data()
 
     async def set_room_is_public_appservice(
-        self, room_id, appservice_id, network_id, is_public
-    ):
+        self, room_id: str, appservice_id: str, network_id: str, is_public: bool
+    ) -> None:
         """Edit the appservice/network specific public room list.
 
         Each appservice can have a number of published room lists associated
@@ -1557,11 +1581,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         network.
 
         Args:
-            room_id (str)
-            appservice_id (str)
-            network_id (str)
-            is_public (bool): Whether to publish or unpublish the room from the
-                list.
+            room_id
+            appservice_id
+            network_id
+            is_public: Whether to publish or unpublish the room from the list.
         """
 
         if is_public:
@@ -1626,7 +1649,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             event_report: json list of information from event report
         """
 
-        def _get_event_report_txn(txn, report_id):
+        def _get_event_report_txn(
+            txn: LoggingTransaction, report_id: int
+        ) -> Optional[Dict[str, Any]]:
 
             sql = """
                 SELECT
@@ -1698,9 +1723,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             count: total number of event reports matching the filter criteria
         """
 
-        def _get_event_reports_paginate_txn(txn):
+        def _get_event_reports_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
             filters = []
-            args = []
+            args: List[object] = []
 
             if user_id:
                 filters.append("er.user_id LIKE ?")
@@ -1724,7 +1751,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 where_clause
             )
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = """
                 SELECT