diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c2891cb07f..eb1118d2cb 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
updates.
"""
- def _count_state_group_hops_txn(self, txn, state_group):
+ def _count_state_group_hops_txn(
+ self, txn: LoggingTransaction, state_group: int
+ ) -> int:
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
+ next_group: Optional[int] = state_group
count = 0
while next_group:
@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter: Optional[StateFilter] = None
- ):
+ self,
+ txn: LoggingTransaction,
+ groups: List[int],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Mapping[int, StateMap[str]]:
state_filter = state_filter or StateFilter.all()
- results = {group: {} for group in groups}
+ results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""
for group in groups:
- args = [group]
+ args: List[Union[int, str]] = [group]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
- next_group = group
+ next_group: Optional[int] = group
while next_group:
# We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
allow_none=True,
)
+ # The results shouldn't be considered mutable.
return results
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
- async def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(
+ self, progress: dict, batch_size: int
+ ) -> int:
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
max_group = rows[0][0]
- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
new_last_state_group = last_state_group
for count in range(batch_size):
txn.execute(
@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- (prev_group,) = txn.fetchone()
+ # There will be a result due to the coalesce.
+ (prev_group,) = txn.fetchone() # type: ignore
new_last_state_group = state_group
if prev_group:
@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
# otherwise read performance degrades.
continue
- prev_state = self._get_state_groups_from_groups_txn(
+ prev_state_by_group = self._get_state_groups_from_groups_txn(
txn, [prev_group]
)
- prev_state = prev_state[prev_group]
+ prev_state = prev_state_by_group[prev_group]
- curr_state = self._get_state_groups_from_groups_txn(
+ curr_state_by_group = self._get_state_groups_from_groups_txn(
txn, [state_group]
)
- curr_state = curr_state[state_group]
+ curr_state = curr_state_by_group[state_group]
if not set(prev_state.keys()) - set(curr_state.keys()):
# We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return result * BATCH_SIZE_SCALE_FACTOR
- async def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
+ async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+ def reindex_txn(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
|