diff --git a/changelog.d/11575.misc b/changelog.d/11575.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11575.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index e38ad635aa..cbe1e8302c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -37,7 +37,6 @@ exclude = (?x)
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
- |synapse/storage/databases/main/room.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
@@ -205,6 +204,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.room]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 447e3ce571..6aa910dd10 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1020,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
- if old_room and old_room["is_public"]:
+ if old_room is not None and old_room["is_public"]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)
@@ -1031,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
for group_id in local_group_ids:
# Add new the new room to those groups
- await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+ await self.store.add_room_to_group(
+ group_id, room_id, old_room is not None and old_room["is_public"]
+ )
# Remove the old room from those groups
await self.store.remove_room_from_group(group_id, old_room_id)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 716b25dd34..a594223fc6 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -149,7 +149,6 @@ class DataStore(
],
)
- self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6cf6cc8484..4472335af9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -17,7 +17,7 @@ import collections
import logging
from abc import abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@@ -29,8 +29,9 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -75,7 +76,7 @@ class RoomSortOrder(Enum):
STATE_EVENTS = "state_events"
-class RoomWorkerStore(SQLBaseStore):
+class RoomWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -92,7 +93,7 @@ class RoomWorkerStore(SQLBaseStore):
room_creator_user_id: str,
is_public: bool,
room_version: RoomVersion,
- ):
+ ) -> None:
"""Stores a room.
Args:
@@ -120,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def get_room(self, room_id: str) -> dict:
+ async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve a room.
Args:
@@ -145,7 +146,9 @@ class RoomWorkerStore(SQLBaseStore):
A dict containing the room information, or None if the room is unknown.
"""
- def get_room_with_stats_txn(txn, room_id):
+ def get_room_with_stats_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -194,7 +197,7 @@ class RoomWorkerStore(SQLBaseStore):
ignore_non_federatable: If true filters out non-federatable rooms
"""
- def _count_public_rooms_txn(txn):
+ def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
if network_tuple:
@@ -235,7 +238,7 @@ class RoomWorkerStore(SQLBaseStore):
}
txn.execute(sql, query_args)
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
@@ -244,11 +247,11 @@ class RoomWorkerStore(SQLBaseStore):
async def get_room_count(self) -> int:
"""Retrieve the total number of rooms."""
- def f(txn):
+ def f(txn: LoggingTransaction) -> int:
sql = "SELECT count(*) FROM rooms"
txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
+ row = cast(Tuple[int], txn.fetchone())
+ return row[0]
return await self.db_pool.runInteraction("get_rooms", f)
@@ -260,7 +263,7 @@ class RoomWorkerStore(SQLBaseStore):
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
- ):
+ ) -> List[Dict[str, Any]]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
@@ -381,7 +384,9 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
- def _get_largest_public_rooms_txn(txn):
+ def _get_largest_public_rooms_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
txn.execute(sql, query_args)
results = self.db_pool.cursor_to_dict(txn)
@@ -444,7 +449,7 @@ class RoomWorkerStore(SQLBaseStore):
"""
# Filter room names by a string
where_statement = ""
- search_pattern = []
+ search_pattern: List[object] = []
if search_term:
where_statement = """
WHERE LOWER(state.name) LIKE ?
@@ -552,7 +557,9 @@ class RoomWorkerStore(SQLBaseStore):
where_statement,
)
- def _get_rooms_paginate_txn(txn):
+ def _get_rooms_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
# Add the search term into the WHERE clause
# and execute the data query
txn.execute(info_sql, search_pattern + [limit, start])
@@ -584,7 +591,7 @@ class RoomWorkerStore(SQLBaseStore):
# Add the search term into the WHERE clause if present
txn.execute(count_sql, search_pattern)
- room_count = txn.fetchone()
+ room_count = cast(Tuple[int], txn.fetchone())
return rooms, room_count[0]
return await self.db_pool.runInteraction(
@@ -629,7 +636,7 @@ class RoomWorkerStore(SQLBaseStore):
burst_count: How many actions that can be performed before being limited.
"""
- def set_ratelimit_txn(txn):
+ def set_ratelimit_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="ratelimit_override",
@@ -652,7 +659,7 @@ class RoomWorkerStore(SQLBaseStore):
user_id: user ID of the user
"""
- def delete_ratelimit_txn(txn):
+ def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
row = self.db_pool.simple_select_one_txn(
txn,
table="ratelimit_override",
@@ -676,7 +683,7 @@ class RoomWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
@cached()
- async def get_retention_policy_for_room(self, room_id):
+ async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -685,13 +692,15 @@ class RoomWorkerStore(SQLBaseStore):
configuration).
Args:
- room_id (str): The ID of the room to get the retention policy of.
+ room_id: The ID of the room to get the retention policy of.
Returns:
- dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+ A dict containing "min_lifetime" and "max_lifetime" for this room.
"""
- def get_retention_policy_for_room_txn(txn):
+ def get_retention_policy_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@@ -716,19 +725,23 @@ class RoomWorkerStore(SQLBaseStore):
"max_lifetime": self.config.retention.retention_default_max_lifetime,
}
- row = ret[0]
+ min_lifetime = ret[0]["min_lifetime"]
+ max_lifetime = ret[0]["max_lifetime"]
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
- if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
+ if min_lifetime is None:
+ min_lifetime = self.config.retention.retention_default_min_lifetime
- if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
+ if max_lifetime is None:
+ max_lifetime = self.config.retention.retention_default_max_lifetime
- return row
+ return {
+ "min_lifetime": min_lifetime,
+ "max_lifetime": max_lifetime,
+ }
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
@@ -740,7 +753,9 @@ class RoomWorkerStore(SQLBaseStore):
The local and remote media as a lists of the media IDs.
"""
- def _get_media_mxcs_in_room_txn(txn):
+ def _get_media_mxcs_in_room_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], List[str]]:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
@@ -766,7 +781,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.info("Quarantining media in room: %s", room_id)
- def _quarantine_media_in_room_txn(txn):
+ def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
@@ -776,13 +791,11 @@ class RoomWorkerStore(SQLBaseStore):
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
- def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ def _get_media_mxcs_in_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> Tuple[List[str], List[Tuple[str, str]]]:
"""Retrieves all the local and remote media MXC URIs in a given room
- Args:
- txn (cursor)
- room_id (str)
-
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
@@ -850,7 +863,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server.server_name
- def _quarantine_media_by_id_txn(txn):
+ def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
local_mxcs = [media_id] if is_local else []
remote_mxcs = [(server_name, media_id)] if not is_local else []
@@ -872,7 +885,7 @@ class RoomWorkerStore(SQLBaseStore):
quarantined_by: The ID of the user who made the quarantine request
"""
- def _quarantine_media_by_user_txn(txn):
+ def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
@@ -880,7 +893,9 @@ class RoomWorkerStore(SQLBaseStore):
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
- def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+ def _get_media_ids_by_user_txn(
+ self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
+ ) -> List[str]:
"""Retrieves local media IDs by a given user
Args:
@@ -909,7 +924,7 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_txn(
self,
- txn,
+ txn: LoggingTransaction,
local_mxcs: List[str],
remote_mxcs: List[Tuple[str, str]],
quarantined_by: Optional[str],
@@ -937,12 +952,15 @@ class RoomWorkerStore(SQLBaseStore):
# set quarantine
if quarantined_by is not None:
sql += "AND safe_from_quarantine = ?"
- rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ txn.executemany(
+ sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+ )
# remove from quarantine
else:
- rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+ txn.executemany(
+ sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+ )
- txn.executemany(sql, rows)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
@@ -960,7 +978,7 @@ class RoomWorkerStore(SQLBaseStore):
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
- ) -> Dict[str, dict]:
+ ) -> Dict[str, Dict[str, Optional[int]]]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
@@ -980,7 +998,9 @@ class RoomWorkerStore(SQLBaseStore):
"min_lifetime" (int|None), and "max_lifetime" (int|None).
"""
- def get_rooms_for_retention_period_in_range_txn(txn):
+ def get_rooms_for_retention_period_in_range_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, Optional[int]]]:
range_conditions = []
args = []
@@ -1067,8 +1087,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self.config = hs.config
-
self.db_pool.updates.register_background_update_handler(
"insert_room_retention",
self._background_insert_retention,
@@ -1099,7 +1117,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_populate_rooms_creator_column,
)
- async def _background_insert_retention(self, progress, batch_size):
+ async def _background_insert_retention(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's
@@ -1109,7 +1129,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
last_room = progress.get("room_id", "")
- def _background_insert_retention_txn(txn):
+ def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
txn.execute(
"""
SELECT state.room_id, state.event_id, events.json
@@ -1168,15 +1188,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _background_add_rooms_room_version_column(
- self, progress: dict, batch_size: int
- ):
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to go and add room version information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
- def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+ def _background_add_rooms_room_version_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT room_id, json FROM current_state_events
INNER JOIN event_json USING (room_id, event_id)
@@ -1237,7 +1259,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _remove_tombstoned_rooms_from_directory(
- self, progress, batch_size
+ self, progress: JsonDict, batch_size: int
) -> int:
"""Removes any rooms with tombstone events from the room directory
@@ -1247,7 +1269,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
last_room = progress.get("room_id", "")
- def _get_rooms(txn):
+ def _get_rooms(txn: LoggingTransaction) -> List[str]:
txn.execute(
"""
SELECT room_id
@@ -1285,7 +1307,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return len(rooms)
@abstractmethod
- def set_room_is_public(self, room_id, is_public):
+ def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
# this will need to be implemented if a background update is performed with
# existing (tombstoned, public) rooms in the database.
#
@@ -1332,7 +1354,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
32-bit integer field.
"""
- def process(txn: Cursor) -> int:
+ def process(txn: LoggingTransaction) -> int:
last_room = progress.get("last_room", "")
txn.execute(
"""
@@ -1389,15 +1411,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return 0
async def _background_populate_rooms_creator_column(
- self, progress: dict, batch_size: int
- ):
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to go and add creator information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
- def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+ def _background_populate_rooms_creator_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
sql = """
SELECT room_id, json FROM event_json
INNER JOIN rooms AS room USING (room_id)
@@ -1448,7 +1472,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
-class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -1457,11 +1481,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
):
super().__init__(database, db_conn, hs)
- self.config = hs.config
+ self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
async def upsert_room_on_join(
self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
- ):
+ ) -> None:
"""Ensure that the room is stored in the table
Called when we join a room over federation, and overwrites any room version
@@ -1507,7 +1531,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
- ):
+ ) -> None:
"""
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
@@ -1547,8 +1571,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.hs.get_notifier().on_new_replication_data()
async def set_room_is_public_appservice(
- self, room_id, appservice_id, network_id, is_public
- ):
+ self, room_id: str, appservice_id: str, network_id: str, is_public: bool
+ ) -> None:
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
@@ -1557,11 +1581,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
network.
Args:
- room_id (str)
- appservice_id (str)
- network_id (str)
- is_public (bool): Whether to publish or unpublish the room from the
- list.
+ room_id
+ appservice_id
+ network_id
+ is_public: Whether to publish or unpublish the room from the list.
"""
if is_public:
@@ -1626,7 +1649,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
event_report: json list of information from event report
"""
- def _get_event_report_txn(txn, report_id):
+ def _get_event_report_txn(
+ txn: LoggingTransaction, report_id: int
+ ) -> Optional[Dict[str, Any]]:
sql = """
SELECT
@@ -1698,9 +1723,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
count: total number of event reports matching the filter criteria
"""
- def _get_event_reports_paginate_txn(txn):
+ def _get_event_reports_paginate_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], int]:
filters = []
- args = []
+ args: List[object] = []
if user_id:
filters.append("er.user_id LIKE ?")
@@ -1724,7 +1751,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
where_clause
)
txn.execute(sql, args)
- count = txn.fetchone()[0]
+ count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT
|