summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-12-30 08:09:53 -0500
committerGitHub <noreply@github.com>2020-12-30 08:09:53 -0500
commit637282bb5019ce1656001927eea1be46c4854815 (patch)
tree7e9a9ba64918de5e995241a02054174ced099d75
parentDoc/move database setup instructions in install md (#8987) (diff)
downloadsynapse-637282bb5019ce1656001927eea1be46c4854815.tar.xz
Add additional type hints to the storage module. (#8980)
Diffstat (limited to '')
-rw-r--r--changelog.d/8980.misc1
-rw-r--r--mypy.ini10
-rw-r--r--synapse/handlers/initial_sync.py4
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/storage/__init__.py9
-rw-r--r--synapse/storage/_base.py36
-rw-r--r--synapse/storage/background_updates.py111
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/prepare_database.py104
-rw-r--r--synapse/storage/purge_events.py11
-rw-r--r--synapse/storage/relations.py44
-rw-r--r--synapse/storage/state.py35
12 files changed, 224 insertions, 148 deletions
diff --git a/changelog.d/8980.misc b/changelog.d/8980.misc
new file mode 100644
index 0000000000..83ef3c5def
--- /dev/null
+++ b/changelog.d/8980.misc
@@ -0,0 +1 @@
+Add type hints to the base storage code.
diff --git a/mypy.ini b/mypy.ini
index 1e88909d46..a54f34fe24 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -70,6 +70,9 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/__init__.py,
+  synapse/storage/_base.py,
+  synapse/storage/background_updates.py,
   synapse/storage/databases/main/appservice.py,
   synapse/storage/databases/main/events.py,
   synapse/storage/databases/main/pusher.py,
@@ -78,8 +81,15 @@ files =
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
+  synapse/storage/keys.py,
   synapse/storage/persist_events.py,
+  synapse/storage/prepare_database.py,
+  synapse/storage/purge_events.py,
+  synapse/storage/push_rule.py,
+  synapse/storage/relations.py,
+  synapse/storage/roommember.py,
   synapse/storage/state.py,
+  synapse/storage/types.py,
   synapse/storage/util,
   synapse/streams,
   synapse/types.py,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index cb11754bf8..fbd8df9dcc 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
-        room_state = await self.state_store.get_state_for_events([member_event_id])
-
-        room_state = room_state[member_event_id]
+        room_state = await self.state_store.get_state_for_event(member_event_id)
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9827c7eb8d..5c7590f38e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -554,7 +554,7 @@ class SyncHandler:
             event.event_id, state_filter=state_filter
         )
         if event.is_state():
-            state_ids = state_ids.copy()
+            state_ids = dict(state_ids)
             state_ids[(event.type, event.state_key)] = event.event_id
         return state_ids
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d5b..c0d9d1240f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
 data stores associated with them (e.g. the schema version tables), which are
 stored in `synapse.storage.schema`.
 """
+from typing import TYPE_CHECKING
 
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main import DataStore
@@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
 from synapse.storage.purge_events import PurgeEventsStorage
 from synapse.storage.state import StateGroupStorage
 
-__all__ = ["DataStores", "DataStore"]
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
 
 
 class Storage:
     """The high level interfaces for talking to various storage layers.
     """
 
-    def __init__(self, hs, stores: Databases):
+    def __init__(self, hs: "HomeServer", stores: Databases):
         # We include the main data store here mainly so that we don't have to
         # rewrite all the existing code to split it into high vs low level
         # interfaces.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded1b..a25c4093bc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@
 import logging
 import random
 from abc import ABCMeta
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
 
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
-from synapse.types import Collection, get_domain_from_id
+from synapse.storage.types import Connection
+from synapse.types import Collection, StreamToken, get_domain_from_id
 from synapse.util import json_decoder
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
         self.db_pool = database
         self.rand = random.SystemRandom()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: StreamToken,
+        rows: Iterable[Any],
+    ) -> None:
         pass
 
-    def _invalidate_state_caches(self, room_id, members_changed):
+    def _invalidate_state_caches(
+        self, room_id: str, members_changed: Iterable[str]
+    ) -> None:
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
 
         Args:
-            room_id (str): Room where state changed
-            members_changed (iterable[str]): The user_ids of members that have
-                changed
+            room_id: Room where state changed
+            members_changed: The user_ids of members that have changed
         """
         for host in {get_domain_from_id(u) for u in members_changed}:
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
-    ):
+    ) -> None:
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
             cache.invalidate(tuple(key))
 
 
-def db_to_json(db_content):
+def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
     """
     Take some data from a database row and return a JSON-decoded object.
 
     Args:
-        db_content (memoryview|buffer|bytes|bytearray|unicode)
+        db_content: The JSON-encoded contents from the database.
+
+    Returns:
+        The object decoded from JSON.
     """
     # psycopg2 on Python 3 returns memoryview objects, which we need to
     # cast to bytes to decode
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..29b8ca676a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 
 from . import engines
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.storage.database import DatabasePool, LoggingTransaction
+
 logger = logging.getLogger(__name__)
 
 
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         self.name = name
         self.total_item_count = 0
-        self.total_duration_ms = 0
-        self.avg_item_count = 0
-        self.avg_duration_ms = 0
+        self.total_duration_ms = 0.0
+        self.avg_item_count = 0.0
+        self.avg_duration_ms = 0.0
 
-    def update(self, item_count, duration_ms):
+    def update(self, item_count: int, duration_ms: float) -> None:
         """Update the stats after doing an update"""
         self.total_item_count += item_count
         self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
         self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
         self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
 
-    def average_items_per_ms(self):
+    def average_items_per_ms(self) -> Optional[float]:
         """An estimate of how long it takes to do a single update.
         Returns:
             A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
             # changes in how long the update process takes.
             return float(self.avg_item_count) / float(self.avg_duration_ms)
 
-    def total_items_per_ms(self):
+    def total_items_per_ms(self) -> Optional[float]:
         """An estimate of how long it takes to do a single update.
         Returns:
             A duration in ms as a float
@@ -83,21 +88,25 @@ class BackgroundUpdater:
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, hs, database):
+    def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
         # if a background update is currently running, its name.
         self._current_background_update = None  # type: Optional[str]
 
-        self._background_update_performance = {}
-        self._background_update_handlers = {}
+        self._background_update_performance = (
+            {}
+        )  # type: Dict[str, BackgroundUpdatePerformance]
+        self._background_update_handlers = (
+            {}
+        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
         self._all_done = False
 
-    def start_doing_background_updates(self):
+    def start_doing_background_updates(self) -> None:
         run_as_background_process("background_updates", self.run_background_updates)
 
-    async def run_background_updates(self, sleep=True):
+    async def run_background_updates(self, sleep: bool = True) -> None:
         logger.info("Starting background schema updates")
         while True:
             if sleep:
@@ -148,7 +157,7 @@ class BackgroundUpdater:
 
         return False
 
-    async def has_completed_background_update(self, update_name) -> bool:
+    async def has_completed_background_update(self, update_name: str) -> bool:
         """Check if the given background update has finished running.
         """
         if self._all_done:
@@ -173,8 +182,7 @@ class BackgroundUpdater:
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms(float): How long we want to spend
-                updating.
+            desired_duration_ms: How long we want to spend updating.
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -220,6 +228,7 @@ class BackgroundUpdater:
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
+        assert self._current_background_update is not None
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
@@ -273,7 +282,11 @@ class BackgroundUpdater:
 
         return len(self._background_update_performance)
 
-    def register_background_update_handler(self, update_name, update_handler):
+    def register_background_update_handler(
+        self,
+        update_name: str,
+        update_handler: Callable[[JsonDict, int], Awaitable[int]],
+    ):
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -287,12 +300,12 @@ class BackgroundUpdater:
         The handler is responsible for updating the progress of the update.
 
         Args:
-            update_name(str): The name of the update that this code handles.
-            update_handler(function): The function that does the update.
+            update_name: The name of the update that this code handles.
+            update_handler: The function that does the update.
         """
         self._background_update_handlers[update_name] = update_handler
 
-    def register_noop_background_update(self, update_name):
+    def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
 
         This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
         also be called to clear the update.
 
         Args:
-            update_name (str): Name of update
+            update_name: Name of update
         """
 
-        async def noop_update(progress, batch_size):
+        async def noop_update(progress: JsonDict, batch_size: int) -> int:
             await self._end_background_update(update_name)
             return 1
 
@@ -313,14 +326,14 @@ class BackgroundUpdater:
 
     def register_background_index_update(
         self,
-        update_name,
-        index_name,
-        table,
-        columns,
-        where_clause=None,
-        unique=False,
-        psql_only=False,
-    ):
+        update_name: str,
+        index_name: str,
+        table: str,
+        columns: Iterable[str],
+        where_clause: Optional[str] = None,
+        unique: bool = False,
+        psql_only: bool = False,
+    ) -> None:
         """Helper for store classes to do a background index addition
 
         To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
         2. In the Store constructor, call this method
 
         Args:
-            update_name (str): update_name to register for
-            index_name (str): name of index to add
-            table (str): table to add index to
-            columns (list[str]): columns/expressions to include in index
-            unique (bool): true to make a UNIQUE index
+            update_name: update_name to register for
+            index_name: name of index to add
+            table: table to add index to
+            columns: columns/expressions to include in index
+            unique: true to make a UNIQUE index
             psql_only: true to only create this index on psql databases (useful
                 for virtual sqlite tables)
         """
 
-        def create_index_psql(conn):
+        def create_index_psql(conn: Connection) -> None:
             conn.rollback()
             # postgres insists on autocommit for the index
-            conn.set_session(autocommit=True)
+            conn.set_session(autocommit=True)  # type: ignore
 
             try:
                 c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
                 logger.debug("[SQL] %s", sql)
                 c.execute(sql)
             finally:
-                conn.set_session(autocommit=False)
+                conn.set_session(autocommit=False)  # type: ignore
 
-        def create_index_sqlite(conn):
+        def create_index_sqlite(conn: Connection) -> None:
             # Sqlite doesn't support concurrent creation of indexes.
             #
             # We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
             c.execute(sql)
 
         if isinstance(self.db_pool.engine, engines.PostgresEngine):
-            runner = create_index_psql
+            runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
         elif psql_only:
             runner = None
         else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
             "background_updates", keyvalues={"update_name": update_name}
         )
 
-    async def _background_update_progress(self, update_name: str, progress: dict):
+    async def _background_update_progress(
+        self, update_name: str, progress: dict
+    ) -> None:
         """Update the progress of a background update
 
         Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
             progress: The progress of the update.
         """
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "background_update_progress",
             self._background_update_progress_txn,
             update_name,
             progress,
         )
 
-    def _background_update_progress_txn(self, txn, update_name, progress):
+    def _background_update_progress_txn(
+        self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
+    ) -> None:
         """Update the progress of a background update
 
         Args:
-            txn(cursor): The transaction.
-            update_name(str): The name of the background update task
-            progress(dict): The progress of the update.
+            txn: The transaction.
+            update_name: The name of the background update task
+            progress: The progress of the update.
         """
 
         progress_json = json_encoder.encode(progress)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7bae..c03871f393 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@
 import logging
 
 import attr
+from signedjson.types import VerifyKey
 
 logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, frozen=True)
 class FetchKeyResult:
-    verify_key = attr.ib()  # VerifyKey: the key itself
-    valid_until_ts = attr.ib()  # int: how long we can use this key for
+    verify_key = attr.ib(type=VerifyKey)  # the key itself
+    valid_until_ts = attr.ib(type=int)  # how long we can use this key for
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..f91a2eae7a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
 import os
 import re
 from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
 
 import attr
+from typing_extensions import Counter as CounterType
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingDatabaseConnection
@@ -70,7 +71,7 @@ def prepare_database(
     db_conn: LoggingDatabaseConnection,
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
-    databases: Collection[str] = ["main", "state"],
+    databases: Collection[str] = ("main", "state"),
 ):
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
@@ -155,7 +156,9 @@ def prepare_database(
         raise
 
 
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+    cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
     """Sets up the physical database by finding a base set of "full schemas" and
     then applying any necessary deltas, including schemas from the given data
     stores.
@@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
     folder as well those in the data stores specified.
 
     Args:
-        cur (Cursor): a database cursor
-        database_engine (DatabaseEngine)
-        databases (list[str]): The names of the databases to instantiate
-            on the given physical database.
+        cur: a database cursor
+        database_engine
+        databases: The names of the databases to instantiate on the given physical database.
     """
 
     # We're about to set up a brand new database so we check that its
@@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
     database_engine.check_new_database(cur)
 
     current_dir = os.path.join(dir_path, "schema", "full_schemas")
-    directory_entries = os.listdir(current_dir)
 
     # First we find the highest full schema version we have
     valid_versions = []
 
-    for filename in directory_entries:
+    for filename in os.listdir(current_dir):
         try:
             ver = int(filename)
         except ValueError:
@@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
         for database in databases
     )
 
-    directory_entries = []
+    directory_entries = []  # type: List[_DirectoryListing]
     for directory in directories:
         directory_entries.extend(
             _DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
 
 
 def _upgrade_existing_database(
-    cur,
-    current_version,
-    applied_delta_files,
-    upgraded,
-    database_engine,
-    config,
-    databases,
-    is_empty=False,
-):
+    cur: Cursor,
+    current_version: int,
+    applied_delta_files: List[str],
+    upgraded: bool,
+    database_engine: BaseDatabaseEngine,
+    config: Optional[HomeServerConfig],
+    databases: Collection[str],
+    is_empty: bool = False,
+) -> None:
     """Upgrades an existing physical database.
 
     Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +324,20 @@ def _upgrade_existing_database(
     for a version before applying those in the next version.
 
     Args:
-        cur (Cursor)
-        current_version (int): The current version of the schema.
-        applied_delta_files (list): A list of deltas that have already been
-            applied.
-        upgraded (bool): Whether the current version was generated by having
+        cur
+        current_version: The current version of the schema.
+        applied_delta_files: A list of deltas that have already been applied.
+        upgraded: Whether the current version was generated by having
             applied deltas or from full schema file. If `True` the function
             will never apply delta files for the given `current_version`, since
             the current_version wasn't generated by applying those delta files.
-        database_engine (DatabaseEngine)
-        config (synapse.config.homeserver.HomeServerConfig|None):
+        database_engine
+        config:
             None if we are initialising a blank database, otherwise the application
             config
-        databases (list[str]): The names of the databases to instantiate
+        databases: The names of the databases to instantiate
             on the given physical database.
-        is_empty (bool): Is this a blank database? I.e. do we need to run the
+        is_empty: Is this a blank database? I.e. do we need to run the
             upgrade portions of the delta scripts.
     """
     if is_empty:
@@ -358,6 +358,7 @@ def _upgrade_existing_database(
     if not is_empty and "main" in databases:
         from synapse.storage.databases.main import check_database_before_upgrade
 
+        assert config is not None
         check_database_before_upgrade(cur, database_engine, config)
 
     start_ver = current_version
@@ -388,10 +389,10 @@ def _upgrade_existing_database(
             )
 
         # Used to check if we have any duplicate file names
-        file_name_counter = Counter()
+        file_name_counter = Counter()  # type: CounterType[str]
 
         # Now find which directories have anything of interest.
-        directory_entries = []
+        directory_entries = []  # type: List[_DirectoryListing]
         for directory in directories:
             logger.debug("Looking for schema deltas in %s", directory)
             try:
@@ -445,11 +446,11 @@ def _upgrade_existing_database(
 
                 module_name = "synapse.storage.v%d_%s" % (v, root_name)
                 with open(absolute_path) as python_file:
-                    module = imp.load_source(module_name, absolute_path, python_file)
+                    module = imp.load_source(module_name, absolute_path, python_file)  # type: ignore
                 logger.info("Running script %s", relative_path)
-                module.run_create(cur, database_engine)
+                module.run_create(cur, database_engine)  # type: ignore
                 if not is_empty:
-                    module.run_upgrade(cur, database_engine, config=config)
+                    module.run_upgrade(cur, database_engine, config=config)  # type: ignore
             elif ext == ".pyc" or file_name == "__pycache__":
                 # Sometimes .pyc files turn up anyway even though we've
                 # disabled their generation; e.g. from distribution package
@@ -497,14 +498,15 @@ def _upgrade_existing_database(
     logger.info("Schema now up to date")
 
 
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+    txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
     """Apply the module schemas for the dynamic modules, if any
 
     Args:
         cur: database cursor
-        database_engine: synapse database engine class
-        config (synapse.config.homeserver.HomeServerConfig):
-            application config
+        database_engine:
+        config: application config
     """
     for (mod, _config) in config.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
         )
 
 
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+    cur: Cursor,
+    database_engine: BaseDatabaseEngine,
+    modname: str,
+    names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
     """Apply the module schemas for a single module
 
     Args:
         cur: database cursor
         database_engine: synapse database engine class
-        modname (str): fully qualified name of the module
-        names_and_streams (Iterable[(str, file)]): the names and streams of
-            schemas to be applied
+        modname: fully qualified name of the module
+        names_and_streams: the names and streams of schemas to be applied
     """
     cur.execute(
         "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         )
 
 
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment
 
@@ -594,17 +600,19 @@ def get_statements(f):
         statement_buffer = statements[-1].strip()
 
 
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
     with open(schema_path, "r") as f:
         execute_statements_from_stream(txn, f)
 
 
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
     for statement in get_statements(f):
         cur.execute(statement)
 
 
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+    txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
     # Bluntly try creating the schema_version tables.
     schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
     executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
     current_version = int(row[0]) if row else None
-    upgraded = bool(row[1]) if row else None
 
     if current_version:
         txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
             (current_version,),
         )
         applied_deltas = [d for d, in txn]
+        upgraded = bool(row[1])
         return current_version, applied_deltas, upgraded
 
     return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
     `file_name` attr is kept first.
     """
 
-    file_name = attr.ib()
-    absolute_path = attr.ib()
+    file_name = attr.ib(type=str)
+    absolute_path = attr.ib(type=str)
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index bfa0a9fd06..6c359c1aae 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,7 +15,12 @@
 
 import itertools
 import logging
-from typing import Set
+from typing import TYPE_CHECKING, Set
+
+from synapse.storage.databases import Databases
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -24,10 +29,10 @@ class PurgeEventsStorage:
     """High level interface for purging rooms and event history.
     """
 
-    def __init__(self, hs, stores):
+    def __init__(self, hs: "HomeServer", stores: Databases):
         self.stores = stores
 
-    async def purge_room(self, room_id: str):
+    async def purge_room(self, room_id: str) -> None:
         """Deletes all record of a room
         """
 
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index cec96ad6a7..2564f34b47 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 import logging
+from typing import Any, Dict, List, Optional, Tuple
 
 import attr
 
 from synapse.api.errors import SynapseError
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
@@ -27,18 +29,18 @@ class PaginationChunk:
     """Returned by relation pagination APIs.
 
     Attributes:
-        chunk (list): The rows returned by pagination
-        next_batch (Any|None): Token to fetch next set of results with, if
+        chunk: The rows returned by pagination
+        next_batch: Token to fetch next set of results with, if
             None then there are no more results.
-        prev_batch (Any|None): Token to fetch previous set of results with, if
+        prev_batch: Token to fetch previous set of results with, if
             None then there are no previous results.
     """
 
-    chunk = attr.ib()
-    next_batch = attr.ib(default=None)
-    prev_batch = attr.ib(default=None)
+    chunk = attr.ib(type=List[JsonDict])
+    next_batch = attr.ib(type=Optional[Any], default=None)
+    prev_batch = attr.ib(type=Optional[Any], default=None)
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = {"chunk": self.chunk}
 
         if self.next_batch:
@@ -59,25 +61,25 @@ class RelationPaginationToken:
     boundaries of the chunk as pagination tokens.
 
     Attributes:
-        topological (int): The topological ordering of the boundary event
-        stream (int): The stream ordering of the boundary event.
+        topological: The topological ordering of the boundary event
+        stream: The stream ordering of the boundary event.
     """
 
-    topological = attr.ib()
-    stream = attr.ib()
+    topological = attr.ib(type=int)
+    stream = attr.ib(type=int)
 
     @staticmethod
-    def from_string(string):
+    def from_string(string: str) -> "RelationPaginationToken":
         try:
             t, s = string.split("-")
             return RelationPaginationToken(int(t), int(s))
         except ValueError:
             raise SynapseError(400, "Invalid token")
 
-    def to_string(self):
+    def to_string(self) -> str:
         return "%d-%d" % (self.topological, self.stream)
 
-    def as_tuple(self):
+    def as_tuple(self) -> Tuple[Any, ...]:
         return attr.astuple(self)
 
 
@@ -89,23 +91,23 @@ class AggregationPaginationToken:
     aggregation groups, we can just use them as our pagination token.
 
     Attributes:
-        count (int): The count of relations in the boundar group.
-        stream (int): The MAX stream ordering in the boundary group.
+        count: The count of relations in the boundary group.
+        stream: The MAX stream ordering in the boundary group.
     """
 
-    count = attr.ib()
-    stream = attr.ib()
+    count = attr.ib(type=int)
+    stream = attr.ib(type=int)
 
     @staticmethod
-    def from_string(string):
+    def from_string(string: str) -> "AggregationPaginationToken":
         try:
             c, s = string.split("-")
             return AggregationPaginationToken(int(c), int(s))
         except ValueError:
             raise SynapseError(400, "Invalid token")
 
-    def to_string(self):
+    def to_string(self) -> str:
         return "%d-%d" % (self.count, self.stream)
 
-    def as_tuple(self):
+    def as_tuple(self) -> Tuple[Any, ...]:
         return attr.astuple(self)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.storage.databases import Databases
+
 logger = logging.getLogger(__name__)
 
 # Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
     """High level interface to fetching state for event.
     """
 
-    def __init__(self, hs, stores):
+    def __init__(self, hs: "HomeServer", stores: "Databases"):
         self.stores = stores
 
-    async def get_state_group_delta(self, state_group: int):
+    async def get_state_group_delta(
+        self, state_group: int
+    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
@@ -341,8 +356,8 @@ class StateGroupStorage:
             state_group: The state group used to retrieve state deltas.
 
         Returns:
-            Tuple[Optional[int], Optional[StateMap[str]]]:
-                (prev_group, delta_ids)
+            A tuple of the previous group and a state map of the event IDs which
+            make up the delta between the old and new state groups.
         """
 
         return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:
 
     async def get_state_for_events(
         self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
 
@@ -472,7 +487,7 @@ class StateGroupStorage:
 
     async def get_state_ids_for_events(
         self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
         of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:
 
     async def get_state_for_event(
         self, event_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
 
@@ -516,7 +531,7 @@ class StateGroupStorage:
 
     async def get_state_ids_for_event(
         self, event_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event