summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-04-11 12:41:55 +0100
committerGitHub <noreply@github.com>2022-04-11 12:41:55 +0100
commit961ee75a9b0b25731eea0031b4ba99a79c050844 (patch)
tree279ae54e8e3affba8421c50e8ed3b4e618d64c39
parentMove complement setup stuff into the Synapse repo (#12404) (diff)
downloadsynapse-961ee75a9b0b25731eea0031b4ba99a79c050844.tar.xz
Disallow untyped defs in synapse._scripts (#12422)
Of note: 

* No untyped defs in `register_new_matrix_user`

This one might be contraversial. `request_registration` has three
dependency-injection arguments used for testing. I'm removing the
injection of the `requests` module and using `unitest.mock.patch` in the
test cases instead.

Doing `reveal_type(requests)` and `reveal_type(requests.get)` before the
change:

```
synapse/_scripts/register_new_matrix_user.py:45: note: Revealed type is "Any"
synapse/_scripts/register_new_matrix_user.py:46: note: Revealed type is "Any"
```

And after:

```
synapse/_scripts/register_new_matrix_user.py:44: note: Revealed type is "types.ModuleType"
synapse/_scripts/register_new_matrix_user.py:45: note: Revealed type is "def (url: Union[builtins.str, builtins.bytes], params: Union[Union[_typeshed.SupportsItems[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]], Tuple[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]], typing.Iterable[Tuple[Union[builtins.str, builtins.bytes, builtins.int, builtins.float], Union[builtins.str, builtins.bytes, builtins.int, builtins.float, typing.Iterable[Union[builtins.str, builtins.bytes, builtins.int, builtins.float]], None]]], builtins.str, builtins.bytes], None] =, data: Union[Any, None] =, headers: Union[Any, None] =, cookies: Union[Any, None] =, files: Union[Any, None] =, auth: Union[Any, None] =, timeout: Union[Any, None] =, allow_redirects: builtins.bool =, proxies: Union[Any, None] =, hooks: Union[Any, None] =, stream: Union[Any, None] =, verify: Union[Any, None] =, cert: Union[Any, None] =, json: Union[Any, None] =) -> requests.models.Response"
```

* Drive-by comment in `synapse.storage.types`

* No untyped defs in `synapse_port_db`

This was by far the most painful. I'm happy to break this up into
smaller pieces for review if it's not managable as-is.
Diffstat (limited to '')
-rw-r--r--changelog.d/12422.misc1
-rw-r--r--mypy.ini3
-rwxr-xr-xsynapse/_scripts/export_signing_key.py11
-rwxr-xr-xsynapse/_scripts/generate_config.py2
-rwxr-xr-xsynapse/_scripts/generate_log_config.py2
-rwxr-xr-xsynapse/_scripts/generate_signing_key.py2
-rwxr-xr-xsynapse/_scripts/hash_password.py4
-rwxr-xr-xsynapse/_scripts/move_remote_media_to_new_store.py19
-rw-r--r--synapse/_scripts/register_new_matrix_user.py3
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py221
-rwxr-xr-xsynapse/_scripts/synctl.py10
-rwxr-xr-xsynapse/_scripts/update_synapse_database.py20
-rw-r--r--synapse/storage/types.py1
-rw-r--r--tests/scripts/test_new_matrix_user.py62
14 files changed, 221 insertions, 140 deletions
diff --git a/changelog.d/12422.misc b/changelog.d/12422.misc
new file mode 100644
index 0000000000..3a7cbc34e7
--- /dev/null
+++ b/changelog.d/12422.misc
@@ -0,0 +1 @@
+Make `synapse._scripts` pass type checks.
diff --git a/mypy.ini b/mypy.ini
index c11386b89a..4ccea6fa5a 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -93,6 +93,9 @@ exclude = (?x)
    |tests/utils.py
    )$
 
+[mypy-synapse._scripts.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.api.*]
 disallow_untyped_defs = True
 
diff --git a/synapse/_scripts/export_signing_key.py b/synapse/_scripts/export_signing_key.py
index 66481533e9..12c890bdbd 100755
--- a/synapse/_scripts/export_signing_key.py
+++ b/synapse/_scripts/export_signing_key.py
@@ -15,19 +15,19 @@
 import argparse
 import sys
 import time
-from typing import Optional
+from typing import NoReturn, Optional
 
 from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
 from signedjson.types import VerifyKey
 
 
-def exit(status: int = 0, message: Optional[str] = None):
+def exit(status: int = 0, message: Optional[str] = None) -> NoReturn:
     if message:
         print(message, file=sys.stderr)
     sys.exit(status)
 
 
-def format_plain(public_key: VerifyKey):
+def format_plain(public_key: VerifyKey) -> None:
     print(
         "%s:%s %s"
         % (
@@ -38,7 +38,7 @@ def format_plain(public_key: VerifyKey):
     )
 
 
-def format_for_config(public_key: VerifyKey, expiry_ts: int):
+def format_for_config(public_key: VerifyKey, expiry_ts: int) -> None:
     print(
         '  "%s:%s": { key: "%s", expired_ts: %i }'
         % (
@@ -50,7 +50,7 @@ def format_for_config(public_key: VerifyKey, expiry_ts: int):
     )
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
@@ -94,7 +94,6 @@ def main():
                 message="Error reading key from file %s: %s %s"
                 % (file.name, type(e), e),
             )
-            res = []
         for key in res:
             formatter(get_verify_key(key))
 
diff --git a/synapse/_scripts/generate_config.py b/synapse/_scripts/generate_config.py
index 75fce20b12..08eb8ef114 100755
--- a/synapse/_scripts/generate_config.py
+++ b/synapse/_scripts/generate_config.py
@@ -7,7 +7,7 @@ import sys
 from synapse.config.homeserver import HomeServerConfig
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config-dir",
diff --git a/synapse/_scripts/generate_log_config.py b/synapse/_scripts/generate_log_config.py
index 82fc763140..7ae08ec0e3 100755
--- a/synapse/_scripts/generate_log_config.py
+++ b/synapse/_scripts/generate_log_config.py
@@ -20,7 +20,7 @@ import sys
 from synapse.config.logger import DEFAULT_LOG_CONFIG
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
diff --git a/synapse/_scripts/generate_signing_key.py b/synapse/_scripts/generate_signing_key.py
index bc26d25bfd..3f8f5da75f 100755
--- a/synapse/_scripts/generate_signing_key.py
+++ b/synapse/_scripts/generate_signing_key.py
@@ -20,7 +20,7 @@ from signedjson.key import generate_signing_key, write_signing_keys
 from synapse.util.stringutils import random_string
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
diff --git a/synapse/_scripts/hash_password.py b/synapse/_scripts/hash_password.py
index 708640c7de..3aa29de5bd 100755
--- a/synapse/_scripts/hash_password.py
+++ b/synapse/_scripts/hash_password.py
@@ -9,7 +9,7 @@ import bcrypt
 import yaml
 
 
-def prompt_for_pass():
+def prompt_for_pass() -> str:
     password = getpass.getpass("Password: ")
 
     if not password:
@@ -23,7 +23,7 @@ def prompt_for_pass():
     return password
 
 
-def main():
+def main() -> None:
     bcrypt_rounds = 12
     password_pepper = ""
 
diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py
index f53bf790af..819afaaca6 100755
--- a/synapse/_scripts/move_remote_media_to_new_store.py
+++ b/synapse/_scripts/move_remote_media_to_new_store.py
@@ -42,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
 logger = logging.getLogger()
 
 
-def main(src_repo, dest_repo):
+def main(src_repo: str, dest_repo: str) -> None:
     src_paths = MediaFilePaths(src_repo)
     dest_paths = MediaFilePaths(dest_repo)
     for line in sys.stdin:
@@ -55,14 +55,19 @@ def main(src_repo, dest_repo):
         move_media(parts[0], parts[1], src_paths, dest_paths)
 
 
-def move_media(origin_server, file_id, src_paths, dest_paths):
+def move_media(
+    origin_server: str,
+    file_id: str,
+    src_paths: MediaFilePaths,
+    dest_paths: MediaFilePaths,
+) -> None:
     """Move the given file, and any thumbnails, to the dest repo
 
     Args:
-        origin_server (str):
-        file_id (str):
-        src_paths (MediaFilePaths):
-        dest_paths (MediaFilePaths):
+        origin_server:
+        file_id:
+        src_paths:
+        dest_paths:
     """
     logger.info("%s/%s", origin_server, file_id)
 
@@ -91,7 +96,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
     )
 
 
-def mkdir_and_move(original_file, dest_file):
+def mkdir_and_move(original_file: str, dest_file: str) -> None:
     dirname = os.path.dirname(dest_file)
     if not os.path.exists(dirname):
         logger.debug("mkdir %s", dirname)
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 4ffe6a1ef3..092601f530 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -22,7 +22,7 @@ import logging
 import sys
 from typing import Callable, Optional
 
-import requests as _requests
+import requests
 import yaml
 
 
@@ -33,7 +33,6 @@ def request_registration(
     shared_secret: str,
     admin: bool = False,
     user_type: Optional[str] = None,
-    requests=_requests,
     _print: Callable[[str], None] = print,
     exit: Callable[[int], None] = sys.exit,
 ) -> None:
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 123eaae5c5..12ff79f6e2 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -22,10 +22,26 @@ import sys
 import time
 import traceback
 from types import TracebackType
-from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    NoReturn,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    cast,
+)
 
 import yaml
 from matrix_common.versionstring import get_distribution_version_string
+from typing_extensions import TypedDict
 
 from twisted.internet import defer, reactor as reactor_
 
@@ -36,7 +52,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
 from synapse.storage.databases.main import PushRuleStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@@ -173,6 +189,8 @@ end_error_exec_info: Optional[
     Tuple[Type[BaseException], BaseException, TracebackType]
 ] = None
 
+R = TypeVar("R")
+
 
 class Store(
     ClientIpBackgroundUpdateStore,
@@ -195,17 +213,19 @@ class Store(
     PresenceBackgroundUpdateStore,
     GroupServerWorkerStore,
 ):
-    def execute(self, f, *args, **kwargs):
+    def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
 
-    def execute_sql(self, sql, *args):
-        def r(txn):
+    def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
+        def r(txn: LoggingTransaction) -> List[Tuple]:
             txn.execute(sql, args)
             return txn.fetchall()
 
         return self.db_pool.runInteraction("execute_sql", r)
 
-    def insert_many_txn(self, txn, table, headers, rows):
+    def insert_many_txn(
+        self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
+    ) -> None:
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
             table,
             ", ".join(k for k in headers),
@@ -218,14 +238,15 @@ class Store(
             logger.exception("Failed to insert: %s", table)
             raise
 
-    def set_room_is_public(self, room_id, is_public):
+    # Note: the parent method is an `async def`.
+    def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
         raise Exception(
             "Attempt to set room_is_public during port_db: database not empty?"
         )
 
 
 class MockHomeserver:
-    def __init__(self, config):
+    def __init__(self, config: HomeServerConfig):
         self.clock = Clock(reactor)
         self.config = config
         self.hostname = config.server.server_name
@@ -233,24 +254,30 @@ class MockHomeserver:
             "matrix-synapse"
         )
 
-    def get_clock(self):
+    def get_clock(self) -> Clock:
         return self.clock
 
-    def get_reactor(self):
+    def get_reactor(self) -> ISynapseReactor:
         return reactor
 
-    def get_instance_name(self):
+    def get_instance_name(self) -> str:
         return "master"
 
 
 class Porter:
-    def __init__(self, sqlite_config, progress, batch_size, hs_config):
+    def __init__(
+        self,
+        sqlite_config: Dict[str, Any],
+        progress: "Progress",
+        batch_size: int,
+        hs_config: HomeServerConfig,
+    ):
         self.sqlite_config = sqlite_config
         self.progress = progress
         self.batch_size = batch_size
         self.hs_config = hs_config
 
-    async def setup_table(self, table):
+    async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
             row = await self.postgres_store.db_pool.simple_select_one(
@@ -292,7 +319,7 @@ class Porter:
                 )
         else:
 
-            def delete_all(txn):
+            def delete_all(txn: LoggingTransaction) -> None:
                 txn.execute(
                     "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
                 )
@@ -317,7 +344,7 @@ class Porter:
     async def get_table_constraints(self) -> Dict[str, Set[str]]:
         """Returns a map of tables that have foreign key constraints to tables they depend on."""
 
-        def _get_constraints(txn):
+        def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
             # We can pull the information about foreign key constraints out from
             # the postgres schema tables.
             sql = """
@@ -343,8 +370,13 @@ class Porter:
         )
 
     async def handle_table(
-        self, table, postgres_size, table_size, forward_chunk, backward_chunk
-    ):
+        self,
+        table: str,
+        postgres_size: int,
+        table_size: int,
+        forward_chunk: int,
+        backward_chunk: int,
+    ) -> None:
         logger.info(
             "Table %s: %i/%i (rows %i-%i) already ported",
             table,
@@ -391,7 +423,9 @@ class Porter:
 
         while True:
 
-            def r(txn):
+            def r(
+                txn: LoggingTransaction,
+            ) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
                 forward_rows = []
                 backward_rows = []
                 if do_forward[0]:
@@ -418,6 +452,7 @@ class Porter:
             )
 
             if frows or brows:
+                assert headers is not None
                 if frows:
                     forward_chunk = max(row[0] for row in frows) + 1
                 if brows:
@@ -426,7 +461,8 @@ class Porter:
                 rows = frows + brows
                 rows = self._convert_rows(table, headers, rows)
 
-                def insert(txn):
+                def insert(txn: LoggingTransaction) -> None:
+                    assert headers is not None
                     self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
 
                     self.postgres_store.db_pool.simple_update_one_txn(
@@ -448,8 +484,12 @@ class Porter:
                 return
 
     async def handle_search_table(
-        self, postgres_size, table_size, forward_chunk, backward_chunk
-    ):
+        self,
+        postgres_size: int,
+        table_size: int,
+        forward_chunk: int,
+        backward_chunk: int,
+    ) -> None:
         select = (
             "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
             " FROM event_search as es"
@@ -460,7 +500,7 @@ class Porter:
 
         while True:
 
-            def r(txn):
+            def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
                 txn.execute(select, (forward_chunk, self.batch_size))
                 rows = txn.fetchall()
                 headers = [column[0] for column in txn.description]
@@ -474,7 +514,7 @@ class Porter:
 
                 # We have to treat event_search differently since it has a
                 # different structure in the two different databases.
-                def insert(txn):
+                def insert(txn: LoggingTransaction) -> None:
                     sql = (
                         "INSERT INTO event_search (event_id, room_id, key,"
                         " sender, vector, origin_server_ts, stream_ordering)"
@@ -528,7 +568,7 @@ class Porter:
         self,
         db_config: DatabaseConnectionConfig,
         allow_outdated_version: bool = False,
-    ):
+    ) -> Store:
         """Builds and returns a database store using the provided configuration.
 
         Args:
@@ -556,7 +596,7 @@ class Porter:
 
         return store
 
-    async def run_background_updates_on_postgres(self):
+    async def run_background_updates_on_postgres(self) -> None:
         # Manually apply all background updates on the PostgreSQL database.
         postgres_ready = (
             await self.postgres_store.db_pool.updates.has_completed_background_updates()
@@ -568,12 +608,12 @@ class Porter:
             self.progress.set_state("Running background updates on PostgreSQL")
 
         while not postgres_ready:
-            await self.postgres_store.db_pool.updates.do_next_background_update(100)
+            await self.postgres_store.db_pool.updates.do_next_background_update(True)
             postgres_ready = await (
                 self.postgres_store.db_pool.updates.has_completed_background_updates()
             )
 
-    async def run(self):
+    async def run(self) -> None:
         """Ports the SQLite database to a PostgreSQL database.
 
         When a fatal error is met, its message is assigned to the global "end_error"
@@ -609,7 +649,7 @@ class Porter:
 
             self.progress.set_state("Creating port tables")
 
-            def create_port_table(txn):
+            def create_port_table(txn: LoggingTransaction) -> None:
                 txn.execute(
                     "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
                     " table_name varchar(100) NOT NULL UNIQUE,"
@@ -622,7 +662,7 @@ class Porter:
             # We want people to be able to rerun this script from an old port
             # so that they can pick up any missing events that were not
             # ported across.
-            def alter_table(txn):
+            def alter_table(txn: LoggingTransaction) -> None:
                 txn.execute(
                     "ALTER TABLE IF EXISTS port_from_sqlite3"
                     " RENAME rowid TO forward_rowid"
@@ -742,7 +782,9 @@ class Porter:
         finally:
             reactor.stop()
 
-    def _convert_rows(self, table, headers, rows):
+    def _convert_rows(
+        self, table: str, headers: List[str], rows: List[Tuple]
+    ) -> List[Tuple]:
         bool_col_names = BOOLEAN_COLUMNS.get(table, [])
 
         bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
@@ -750,7 +792,7 @@ class Porter:
         class BadValueException(Exception):
             pass
 
-        def conv(j, col):
+        def conv(j: int, col: object) -> object:
             if j in bool_cols:
                 return bool(col)
             if isinstance(col, bytes):
@@ -776,7 +818,7 @@ class Porter:
 
         return outrows
 
-    async def _setup_sent_transactions(self):
+    async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
         # Only save things from the last day
         yesterday = int(time.time() * 1000) - 86400000
 
@@ -788,10 +830,10 @@ class Porter:
             ")"
         )
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
             txn.execute(select)
             rows = txn.fetchall()
-            headers = [column[0] for column in txn.description]
+            headers: List[str] = [column[0] for column in txn.description]
 
             ts_ind = headers.index("ts")
 
@@ -805,7 +847,7 @@ class Porter:
         if inserted_rows:
             max_inserted_rowid = max(r[0] for r in rows)
 
-            def insert(txn):
+            def insert(txn: LoggingTransaction) -> None:
                 self.postgres_store.insert_many_txn(
                     txn, "sent_transactions", headers[1:], rows
                 )
@@ -814,7 +856,7 @@ class Porter:
         else:
             max_inserted_rowid = 0
 
-        def get_start_id(txn):
+        def get_start_id(txn: LoggingTransaction) -> int:
             txn.execute(
                 "SELECT rowid FROM sent_transactions WHERE ts >= ?"
                 " ORDER BY rowid ASC LIMIT 1",
@@ -839,12 +881,13 @@ class Porter:
             },
         )
 
-        def get_sent_table_size(txn):
+        def get_sent_table_size(txn: LoggingTransaction) -> int:
             txn.execute(
                 "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
             )
-            (size,) = txn.fetchone()
-            return int(size)
+            result = txn.fetchone()
+            assert result is not None
+            return int(result[0])
 
         remaining_count = await self.sqlite_store.execute(get_sent_table_size)
 
@@ -852,25 +895,35 @@ class Porter:
 
         return next_chunk, inserted_rows, total_count
 
-    async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = await self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+    async def _get_remaining_count_to_port(
+        self, table: str, forward_chunk: int, backward_chunk: int
+    ) -> int:
+        frows = cast(
+            List[Tuple[int]],
+            await self.sqlite_store.execute_sql(
+                "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
+            ),
         )
 
-        brows = await self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+        brows = cast(
+            List[Tuple[int]],
+            await self.sqlite_store.execute_sql(
+                "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
+            ),
         )
 
         return frows[0][0] + brows[0][0]
 
-    async def _get_already_ported_count(self, table):
+    async def _get_already_ported_count(self, table: str) -> int:
         rows = await self.postgres_store.execute_sql(
             "SELECT count(*) FROM %s" % (table,)
         )
 
         return rows[0][0]
 
-    async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+    async def _get_total_count_to_port(
+        self, table: str, forward_chunk: int, backward_chunk: int
+    ) -> Tuple[int, int]:
         remaining, done = await make_deferred_yieldable(
             defer.gatherResults(
                 [
@@ -891,14 +944,17 @@ class Porter:
         return done, remaining + done
 
     async def _setup_state_group_id_seq(self) -> None:
-        curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+        curr_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
             table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
         )
 
         if not curr_id:
             return
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
+            assert curr_id is not None
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 
@@ -909,7 +965,7 @@ class Porter:
             "setup_user_id_seq", find_max_generated_user_id_localpart
         )
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
 
@@ -931,7 +987,7 @@ class Porter:
             allow_none=True,
         )
 
-        def _setup_events_stream_seqs_set_pos(txn):
+        def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
             if curr_forward_id:
                 txn.execute(
                     "ALTER SEQUENCE events_stream_seq RESTART WITH %s",
@@ -955,17 +1011,20 @@ class Porter:
         """Set a sequence to the correct value."""
         current_stream_ids = []
         for stream_id_table in stream_id_tables:
-            max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
-                table=stream_id_table,
-                keyvalues={},
-                retcol="COALESCE(MAX(stream_id), 1)",
-                allow_none=True,
+            max_stream_id = cast(
+                int,
+                await self.sqlite_store.db_pool.simple_select_one_onecol(
+                    table=stream_id_table,
+                    keyvalues={},
+                    retcol="COALESCE(MAX(stream_id), 1)",
+                    allow_none=True,
+                ),
             )
             current_stream_ids.append(max_stream_id)
 
         next_id = max(current_stream_ids) + 1
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
             sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
             txn.execute(sql + " %s", (next_id,))
 
@@ -974,14 +1033,18 @@ class Porter:
         )
 
     async def _setup_auth_chain_sequence(self) -> None:
-        curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+        curr_chain_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
             table="event_auth_chains",
             keyvalues={},
             retcol="MAX(chain_id)",
             allow_none=True,
         )
 
-        def r(txn):
+        def r(txn: LoggingTransaction) -> None:
+            # Presumably there is at least one row in event_auth_chains.
+            assert curr_chain_id is not None
             txn.execute(
                 "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
                 (curr_chain_id + 1,),
@@ -999,15 +1062,22 @@ class Porter:
 ##############################################
 
 
-class Progress(object):
+class TableProgress(TypedDict):
+    start: int
+    num_done: int
+    total: int
+    perc: int
+
+
+class Progress:
     """Used to report progress of the port"""
 
-    def __init__(self):
-        self.tables = {}
+    def __init__(self) -> None:
+        self.tables: Dict[str, TableProgress] = {}
 
         self.start_time = int(time.time())
 
-    def add_table(self, table, cur, size):
+    def add_table(self, table: str, cur: int, size: int) -> None:
         self.tables[table] = {
             "start": cur,
             "num_done": cur,
@@ -1015,19 +1085,22 @@ class Progress(object):
             "perc": int(cur * 100 / size),
         }
 
-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         data = self.tables[table]
         data["num_done"] = num_done
         data["perc"] = int(num_done * 100 / data["total"])
 
-    def done(self):
+    def done(self) -> None:
+        pass
+
+    def set_state(self, state: str) -> None:
         pass
 
 
 class CursesProgress(Progress):
     """Reports progress to a curses window"""
 
-    def __init__(self, stdscr):
+    def __init__(self, stdscr: "curses.window"):
         self.stdscr = stdscr
 
         curses.use_default_colors()
@@ -1045,7 +1118,7 @@ class CursesProgress(Progress):
 
         super(CursesProgress, self).__init__()
 
-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         super(CursesProgress, self).update(table, num_done)
 
         self.total_processed = 0
@@ -1056,7 +1129,7 @@ class CursesProgress(Progress):
 
         self.render()
 
-    def render(self, force=False):
+    def render(self, force: bool = False) -> None:
         now = time.time()
 
         if not force and now - self.last_update < 0.2:
@@ -1128,12 +1201,12 @@ class CursesProgress(Progress):
         self.stdscr.refresh()
         self.last_update = time.time()
 
-    def done(self):
+    def done(self) -> None:
         self.finished = True
         self.render(True)
         self.stdscr.getch()
 
-    def set_state(self, state):
+    def set_state(self, state: str) -> None:
         self.stdscr.clear()
         self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
         self.stdscr.refresh()
@@ -1142,7 +1215,7 @@ class CursesProgress(Progress):
 class TerminalProgress(Progress):
     """Just prints progress to the terminal"""
 
-    def update(self, table, num_done):
+    def update(self, table: str, num_done: int) -> None:
         super(TerminalProgress, self).update(table, num_done)
 
         data = self.tables[table]
@@ -1151,7 +1224,7 @@ class TerminalProgress(Progress):
             "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
         )
 
-    def set_state(self, state):
+    def set_state(self, state: str) -> None:
         print(state + "...")
 
 
@@ -1159,7 +1232,7 @@ class TerminalProgress(Progress):
 ##############################################
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="A script to port an existing synapse SQLite database to"
         " a new PostgreSQL database."
@@ -1225,7 +1298,7 @@ def main():
     config = HomeServerConfig()
     config.parse_config_dict(hs_config, "", "")
 
-    def start(stdscr=None):
+    def start(stdscr: Optional["curses.window"] = None) -> None:
         progress: Progress
         if stdscr:
             progress = CursesProgress(stdscr)
@@ -1240,7 +1313,7 @@ def main():
         )
 
         @defer.inlineCallbacks
-        def run():
+        def run() -> Generator["defer.Deferred[Any]", Any, None]:
             with LoggingContext("synapse_port_db_run"):
                 yield defer.ensureDeferred(porter.run())
 
diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py
index 1ab36949c7..b4c96ad7f3 100755
--- a/synapse/_scripts/synctl.py
+++ b/synapse/_scripts/synctl.py
@@ -24,7 +24,7 @@ import signal
 import subprocess
 import sys
 import time
-from typing import Iterable, Optional
+from typing import Iterable, NoReturn, Optional, TextIO
 
 import yaml
 
@@ -45,7 +45,7 @@ one of the following:
 --------------------------------------------------------------------------------"""
 
 
-def pid_running(pid):
+def pid_running(pid: int) -> bool:
     try:
         os.kill(pid, 0)
     except OSError as err:
@@ -68,7 +68,7 @@ def pid_running(pid):
     return True
 
 
-def write(message, colour=NORMAL, stream=sys.stdout):
+def write(message: str, colour: str = NORMAL, stream: TextIO = sys.stdout) -> None:
     # Lets check if we're writing to a TTY before colouring
     should_colour = False
     try:
@@ -84,7 +84,7 @@ def write(message, colour=NORMAL, stream=sys.stdout):
         stream.write(colour + message + NORMAL + "\n")
 
 
-def abort(message, colour=RED, stream=sys.stderr):
+def abort(message: str, colour: str = RED, stream: TextIO = sys.stderr) -> NoReturn:
     write(message, colour, stream)
     sys.exit(1)
 
@@ -166,7 +166,7 @@ Worker = collections.namedtuple(
 )
 
 
-def main():
+def main() -> None:
 
     parser = argparse.ArgumentParser()
 
diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py
index 736f58836d..c443522c05 100755
--- a/synapse/_scripts/update_synapse_database.py
+++ b/synapse/_scripts/update_synapse_database.py
@@ -38,25 +38,25 @@ logger = logging.getLogger("update_database")
 class MockHomeserver(HomeServer):
     DATASTORE_CLASS = DataStore  # type: ignore [assignment]
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config: HomeServerConfig):
         super(MockHomeserver, self).__init__(
-            config.server.server_name, reactor=reactor, config=config, **kwargs
-        )
-
-        self.version_string = "Synapse/" + get_distribution_version_string(
-            "matrix-synapse"
+            hostname=config.server.server_name,
+            config=config,
+            reactor=reactor,
+            version_string="Synapse/"
+            + get_distribution_version_string("matrix-synapse"),
         )
 
 
-def run_background_updates(hs):
+def run_background_updates(hs: HomeServer) -> None:
     store = hs.get_datastores().main
 
-    async def run_background_updates():
+    async def run_background_updates() -> None:
         await store.db_pool.updates.run_background_updates(sleep=False)
         # Stop the reactor to exit the script once every background update is run.
         reactor.stop()
 
-    def run():
+    def run() -> None:
         # Apply all background updates on the database.
         defer.ensureDeferred(
             run_as_background_process("background_updates", run_background_updates)
@@ -67,7 +67,7 @@ def run_background_updates(hs):
     reactor.run()
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description=(
             "Updates a synapse database to the latest schema and optionally runs background updates"
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 57f4883bf4..d7d6f1d90e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -45,6 +45,7 @@ class Cursor(Protocol):
         Sequence[
             # Note that this is an approximate typing based on sqlite3 and other
             # drivers, and may not be entirely accurate.
+            # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
             Tuple[
                 str,
                 Optional[Any],
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 6f3c365c9a..19a145eeb6 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 from synapse._scripts.register_new_matrix_user import request_registration
 
@@ -52,16 +52,16 @@ class RegisterTestCase(TestCase):
         out = []
         err_code = []
 
-        request_registration(
-            "user",
-            "pass",
-            "matrix.org",
-            "shared",
-            admin=False,
-            requests=requests,
-            _print=out.append,
-            exit=err_code.append,
-        )
+        with patch("synapse._scripts.register_new_matrix_user.requests", requests):
+            request_registration(
+                "user",
+                "pass",
+                "matrix.org",
+                "shared",
+                admin=False,
+                _print=out.append,
+                exit=err_code.append,
+            )
 
         # We should get the success message making sure everything is OK.
         self.assertIn("Success!", out)
@@ -88,16 +88,16 @@ class RegisterTestCase(TestCase):
         out = []
         err_code = []
 
-        request_registration(
-            "user",
-            "pass",
-            "matrix.org",
-            "shared",
-            admin=False,
-            requests=requests,
-            _print=out.append,
-            exit=err_code.append,
-        )
+        with patch("synapse._scripts.register_new_matrix_user.requests", requests):
+            request_registration(
+                "user",
+                "pass",
+                "matrix.org",
+                "shared",
+                admin=False,
+                _print=out.append,
+                exit=err_code.append,
+            )
 
         # Exit was called
         self.assertEqual(err_code, [1])
@@ -140,16 +140,16 @@ class RegisterTestCase(TestCase):
         out = []
         err_code = []
 
-        request_registration(
-            "user",
-            "pass",
-            "matrix.org",
-            "shared",
-            admin=False,
-            requests=requests,
-            _print=out.append,
-            exit=err_code.append,
-        )
+        with patch("synapse._scripts.register_new_matrix_user.requests", requests):
+            request_registration(
+                "user",
+                "pass",
+                "matrix.org",
+                "shared",
+                admin=False,
+                _print=out.append,
+                exit=err_code.append,
+            )
 
         # Exit was called
         self.assertEqual(err_code, [1])