summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-09-15 09:54:13 -0400
committerGitHub <noreply@github.com>2021-09-15 09:54:13 -0400
commit3eba047d388fd0d798229a0779f343dbda8a2887 (patch)
tree991e2bdf96eec08a830ae5542f3368f656c0781a
parentAdd missing type hints to non-client REST servlets. (#10817) (diff)
downloadsynapse-3eba047d388fd0d798229a0779f343dbda8a2887.tar.xz
Add type hints to state database module. (#10823)
-rw-r--r--changelog.d/10823.misc1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/storage/databases/state/bg_updates.py60
-rw-r--r--synapse/storage/databases/state/store.py136
-rw-r--r--synapse/storage/state.py3
-rw-r--r--synapse/util/caches/dictionary_cache.py4
6 files changed, 133 insertions, 72 deletions
diff --git a/changelog.d/10823.misc b/changelog.d/10823.misc
new file mode 100644
index 0000000000..0532969900
--- /dev/null
+++ b/changelog.d/10823.misc
@@ -0,0 +1 @@
+Add type hints to the state database.
diff --git a/mypy.ini b/mypy.ini
index e9052fa01b..b21e1555ab 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -60,6 +60,7 @@ files =
   synapse/storage/databases/main/session.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
+  synapse/storage/databases/state,
   synapse/storage/database.py,
   synapse/storage/engines,
   synapse/storage/keys.py,
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c2891cb07f..eb1118d2cb 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
     updates.
     """
 
-    def _count_state_group_hops_txn(self, txn, state_group):
+    def _count_state_group_hops_txn(
+        self, txn: LoggingTransaction, state_group: int
+    ) -> int:
         """Given a state group, count how many hops there are in the tree.
 
         This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         else:
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
+            next_group: Optional[int] = state_group
             count = 0
 
             while next_group:
@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             return count
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter: Optional[StateFilter] = None
-    ):
+        self,
+        txn: LoggingTransaction,
+        groups: List[int],
+        state_filter: Optional[StateFilter] = None,
+    ) -> Mapping[int, StateMap[str]]:
         state_filter = state_filter or StateFilter.all()
 
-        results = {group: {} for group in groups}
+        results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
 
@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             """
 
             for group in groups:
-                args = [group]
+                args: List[Union[int, str]] = [group]
                 args.extend(where_args)
 
                 txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
-                next_group = group
+                next_group: Optional[int] = group
 
                 while next_group:
                     # We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                         allow_none=True,
                     )
 
+        # The results shouldn't be considered mutable.
         return results
 
 
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             columns=["room_id"],
         )
 
-    async def _background_deduplicate_state(self, progress, batch_size):
+    async def _background_deduplicate_state(
+        self, progress: dict, batch_size: int
+    ) -> int:
         """This background update will slowly deduplicate state by reencoding
         them as deltas.
         """
@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             )
             max_group = rows[0][0]
 
-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
             new_last_state_group = last_state_group
             for count in range(batch_size):
                 txn.execute(
@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                     " WHERE id < ? AND room_id = ?",
                     (state_group, room_id),
                 )
-                (prev_group,) = txn.fetchone()
+                # There will be a result due to the coalesce.
+                (prev_group,) = txn.fetchone()  # type: ignore
                 new_last_state_group = state_group
 
                 if prev_group:
@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                         # otherwise read performance degrades.
                         continue
 
-                    prev_state = self._get_state_groups_from_groups_txn(
+                    prev_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [prev_group]
                     )
-                    prev_state = prev_state[prev_group]
+                    prev_state = prev_state_by_group[prev_group]
 
-                    curr_state = self._get_state_groups_from_groups_txn(
+                    curr_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [state_group]
                     )
-                    curr_state = curr_state[state_group]
+                    curr_state = curr_state_by_group[state_group]
 
                     if not set(prev_state.keys()) - set(curr_state.keys()):
                         # We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
 
         return result * BATCH_SIZE_SCALE_FACTOR
 
-    async def _background_index_state(self, progress, batch_size):
-        def reindex_txn(conn):
+    async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
                 # postgres insists on autocommit for the index
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..f1e3a27e63 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,43 +13,56 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+
+import attr
 
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 MAX_STATE_DELTA_HOPS = 100
 
 
-class _GetStateGroupDelta(
-    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _GetStateGroupDelta:
     """Return type of get_state_group_delta that implements __len__, which lets
-    us use the itrable flag when caching
+    us use the iterable flag when caching
     """
 
-    __slots__ = []
+    prev_group: Optional[int]
+    delta_ids: Optional[StateMap[str]]
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.delta_ids) if self.delta_ids else 0
 
 
 class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
     """A data store for fetching/storing state groups."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
@@ -81,19 +94,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         # We size the non-members cache to be smaller than the members cache as the
         # vast majority of state in Matrix (today) is member events.
 
-        self._state_group_cache = DictionaryCache(
+        self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
             "*stateGroupCache*",
             # TODO: this hasn't been tuned yet
             50000,
         )
-        self._state_group_members_cache = DictionaryCache(
+        self._state_group_members_cache: DictionaryCache[
+            int, StateKey, str
+        ] = DictionaryCache(
             "*stateGroupMembersCache*",
             500000,
         )
 
-        def get_max_state_group_txn(txn: Cursor):
+        def get_max_state_group_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
-            return txn.fetchone()[0]
+            return txn.fetchone()[0]  # type: ignore
 
         self._state_group_seq_gen = build_sequence_generator(
             db_conn,
@@ -105,15 +120,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     @cached(max_entries=10000, iterable=True)
-    async def get_state_group_delta(self, state_group):
+    async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
         Returns:
-            (prev_group, delta_ids), where both may be None.
+            _GetStateGroupDelta containing prev_group and delta_ids, where both may be None.
         """
 
-        def _get_state_group_delta_txn(txn):
+        def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
             prev_group = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="state_group_edges",
@@ -154,7 +169,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
-        results = {}
+        results: Dict[int, StateMap[str]] = {}
 
         chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
@@ -168,19 +183,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return results
 
-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+    def _get_state_for_group_using_cache(
+        self,
+        cache: DictionaryCache[int, StateKey, str],
+        group: int,
+        state_filter: StateFilter,
+    ) -> Tuple[MutableStateMap[str], bool]:
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            cache: the state group cache to use
+            group: The state group to lookup
+            state_filter: The state filter used to fetch state from the database.
 
-        Returns 2-tuple (`state_dict`, `got_all`).
-        `got_all` is a bool indicating if we successfully retrieved all
-        requests state from the cache, if False we need to query the DB for the
-        missing state.
+        Returns:
+             2-tuple (`state_dict`, `got_all`).
+                `got_all` is a bool indicating if we successfully retrieved all
+                requests state from the cache, if False we need to query the DB for the
+                missing state.
         """
         cache_entry = cache.get(group)
         state_dict_ids = cache_entry.value
@@ -277,8 +297,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state
 
     def _get_state_for_groups_using_cache(
-        self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
-    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+        self,
+        groups: Iterable[int],
+        cache: DictionaryCache[int, StateKey, str],
+        state_filter: StateFilter,
+    ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
 
@@ -310,21 +333,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
     def _insert_into_cache(
         self,
-        group_to_state_dict,
-        state_filter,
-        cache_seq_num_members,
-        cache_seq_num_non_members,
-    ):
+        group_to_state_dict: Dict[int, StateMap[str]],
+        state_filter: StateFilter,
+        cache_seq_num_members: int,
+        cache_seq_num_non_members: int,
+    ) -> None:
         """Inserts results from querying the database into the relevant cache.
 
         Args:
-            group_to_state_dict (dict): The new entries pulled from database.
+            group_to_state_dict: The new entries pulled from database.
                 Map from state group to state dict
-            state_filter (StateFilter): The state filter used to fetch state
+            state_filter: The state filter used to fetch state
                 from the database.
-            cache_seq_num_members (int): Sequence number of member cache since
+            cache_seq_num_members: Sequence number of member cache since
                 last lookup in cache
-            cache_seq_num_non_members (int): Sequence number of member cache since
+            cache_seq_num_non_members: Sequence number of member cache since
                 last lookup in cache
         """
 
@@ -395,7 +418,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             The state group ID
         """
 
-        def _store_state_group_txn(txn):
+        def _store_state_group_txn(txn: LoggingTransaction) -> int:
             if current_state_ids is None:
                 # AFAIK, this can never happen
                 raise Exception("current_state_ids cannot be None")
@@ -426,6 +449,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
                 potential_hops = self._count_state_group_hops_txn(txn, prev_group)
             if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                assert delta_ids is not None
+
                 self.db_pool.simple_insert_txn(
                     txn,
                     table="state_group_edges",
@@ -498,7 +523,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     async def purge_unreferenced_state_groups(
-        self, room_id: str, state_groups_to_delete
+        self, room_id: str, state_groups_to_delete: Collection[int]
     ) -> None:
         """Deletes no longer referenced state groups and de-deltas any state
         groups that reference them.
@@ -506,8 +531,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Args:
             room_id: The room the state groups belong to (must all be in the
                 same room).
-            state_groups_to_delete (Collection[int]): Set of all state groups
-                to delete.
+            state_groups_to_delete: Set of all state groups to delete.
         """
 
         await self.db_pool.runInteraction(
@@ -517,7 +541,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete,
         )
 
-    def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+    def _purge_unreferenced_state_groups(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_groups_to_delete: Collection[int],
+    ) -> None:
         logger.info(
             "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
@@ -546,8 +575,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         # groups to non delta versions.
         for sg in remaining_state_groups:
             logger.info("[purge] de-delta-ing remaining state group %s", sg)
-            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
-            curr_state = curr_state[sg]
+            curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
+            curr_state = curr_state_by_group[sg]
 
             self.db_pool.simple_delete_txn(
                 txn, table="state_groups_state", keyvalues={"state_group": sg}
@@ -605,12 +634,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return {row["state_group"]: row["prev_state_group"] for row in rows}
 
-    async def purge_room_state(self, room_id, state_groups_to_delete):
+    async def purge_room_state(
+        self, room_id: str, state_groups_to_delete: Collection[int]
+    ) -> None:
         """Deletes all record of a room from state tables
 
         Args:
-            room_id (str):
-            state_groups_to_delete (list[int]): State groups to delete
+            room_id:
+            state_groups_to_delete: State groups to delete
         """
 
         await self.db_pool.runInteraction(
@@ -620,7 +651,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete,
         )
 
-    def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+    def _purge_room_state_txn(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_groups_to_delete: Collection[int],
+    ) -> None:
         # first we have to delete the state groups states
         logger.info("[purge] removing %s from state_groups_state", room_id)
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c76529cb57..5e86befde4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -377,7 +377,8 @@ class StateGroupStorage:
             make up the delta between the old and new state groups.
         """
 
-        return await self.stores.state.get_state_group_delta(state_group)
+        state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+        return state_group_delta.prev_group, state_group_delta.delta_ids
 
     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index ade088aae2..485ddb1893 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -130,7 +130,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         sequence: int,
         key: KT,
         value: Dict[DKT, DV],
-        fetched_keys: Optional[Set[DKT]] = None,
+        fetched_keys: Optional[Iterable[DKT]] = None,
     ) -> None:
         """Updates the entry in the cache
 
@@ -155,7 +155,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
                 self._update_or_insert(key, value, fetched_keys)
 
     def _update_or_insert(
-        self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
+        self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
     ) -> None:
         # We pop and reinsert as we need to tell the cache the size may have
         # changed