From 0cd182f296ce44dbaafc9a56f9af2183d21a9443 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 8 Apr 2022 15:00:12 +0100 Subject: Make `synapse._scripts` pass typechecks (#12421) --- synapse/_scripts/export_signing_key.py | 6 +-- synapse/_scripts/move_remote_media_to_new_store.py | 9 ++-- synapse/_scripts/synapse_port_db.py | 53 +++++++++++++--------- synapse/_scripts/update_synapse_database.py | 19 ++++---- 4 files changed, 49 insertions(+), 38 deletions(-) (limited to 'synapse') diff --git a/synapse/_scripts/export_signing_key.py b/synapse/_scripts/export_signing_key.py index 3d254348f1..66481533e9 100755 --- a/synapse/_scripts/export_signing_key.py +++ b/synapse/_scripts/export_signing_key.py @@ -17,8 +17,8 @@ import sys import time from typing import Optional -import nacl.signing from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys +from signedjson.types import VerifyKey def exit(status: int = 0, message: Optional[str] = None): @@ -27,7 +27,7 @@ def exit(status: int = 0, message: Optional[str] = None): sys.exit(status) -def format_plain(public_key: nacl.signing.VerifyKey): +def format_plain(public_key: VerifyKey): print( "%s:%s %s" % ( @@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey): ) -def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): +def format_for_config(public_key: VerifyKey, expiry_ts: int): print( ' "%s:%s": { key: "%s", expired_ts: %i }' % ( diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py index 9667d95dfe..f53bf790af 100755 --- a/synapse/_scripts/move_remote_media_to_new_store.py +++ b/synapse/_scripts/move_remote_media_to_new_store.py @@ -109,10 +109,9 @@ if __name__ == "__main__": parser.add_argument("dest_repo", help="Path to source content repo") args = parser.parse_args() - logging_config = { - "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", - } - logging.basicConfig(**logging_config) + logging.basicConfig( + level=logging.DEBUG if args.v else logging.INFO, + format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + ) main(args.src_repo, args.dest_repo) diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 6324df883b..123eaae5c5 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -21,12 +21,13 @@ import logging import sys import time import traceback -from typing import Dict, Iterable, Optional, Set +from types import TracebackType +from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast import yaml from matrix_common.versionstring import get_distribution_version_string -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as reactor_ from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig @@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import ( from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +from synapse.types import ISynapseReactor from synapse.util import Clock +# Cast safety: Twisted does some naughty magic which replaces the +# twisted.internet.reactor module with a Reactor instance at runtime. +reactor = cast(ISynapseReactor, reactor_) logger = logging.getLogger("synapse_port_db") @@ -159,12 +164,14 @@ IGNORED_TABLES = { # Error returned by the run function. Used at the top-level part of the script to # handle errors and return codes. -end_error = None # type: Optional[str] +end_error: Optional[str] = None # The exec_info for the error, if any. If error is defined but not exec_info the script # will show only the error message without the stacktrace, if exec_info is defined but # not the error then the script will show nothing outside of what's printed in the run # function. If both are defined, the script will print both the error and the stacktrace. -end_error_exec_info = None +end_error_exec_info: Optional[ + Tuple[Type[BaseException], BaseException, TracebackType] +] = None class Store( @@ -236,9 +243,12 @@ class MockHomeserver: return "master" -class Porter(object): - def __init__(self, **kwargs): - self.__dict__.update(kwargs) +class Porter: + def __init__(self, sqlite_config, progress, batch_size, hs_config): + self.sqlite_config = sqlite_config + self.progress = progress + self.batch_size = batch_size + self.hs_config = hs_config async def setup_table(self, table): if table in APPEND_ONLY_TABLES: @@ -323,7 +333,7 @@ class Porter(object): """ txn.execute(sql) - results = {} + results: Dict[str, Set[str]] = {} for table, foreign_table in txn: results.setdefault(table, set()).add(foreign_table) return results @@ -540,7 +550,8 @@ class Porter(object): db_conn, allow_outdated_version=allow_outdated_version ) prepare_database(db_conn, engine, config=self.hs_config) - store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) + # Type safety: ignore that we're using Mock homeservers here. + store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type] db_conn.commit() return store @@ -724,7 +735,9 @@ class Porter(object): except Exception as e: global end_error_exec_info end_error = str(e) - end_error_exec_info = sys.exc_info() + # Type safety: we're in an exception handler, so the exc_info() tuple + # will not be (None, None, None). + end_error_exec_info = sys.exc_info() # type: ignore[assignment] logger.exception("") finally: reactor.stop() @@ -1023,7 +1036,7 @@ class CursesProgress(Progress): curses.init_pair(1, curses.COLOR_RED, -1) curses.init_pair(2, curses.COLOR_GREEN, -1) - self.last_update = 0 + self.last_update = 0.0 self.finished = False @@ -1082,8 +1095,7 @@ class CursesProgress(Progress): left_margin = 5 middle_space = 1 - items = self.tables.items() - items = sorted(items, key=lambda i: (i[1]["perc"], i[0])) + items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0])) for i, (table, data) in enumerate(items): if i + 2 >= rows: @@ -1179,15 +1191,11 @@ def main(): args = parser.parse_args() - logging_config = { - "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", - } - - if args.curses: - logging_config["filename"] = "port-synapse.log" - - logging.basicConfig(**logging_config) + logging.basicConfig( + level=logging.DEBUG if args.v else logging.INFO, + format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + filename="port-synapse.log" if args.curses else None, + ) sqlite_config = { "name": "sqlite3", @@ -1218,6 +1226,7 @@ def main(): config.parse_config_dict(hs_config, "", "") def start(stdscr=None): + progress: Progress if stdscr: progress = CursesProgress(stdscr) else: diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index f43676afaa..736f58836d 100755 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -16,22 +16,27 @@ import argparse import logging import sys +from typing import cast import yaml from matrix_common.versionstring import get_distribution_version_string -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as reactor_ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.types import ISynapseReactor +# Cast safety: Twisted does some naughty magic which replaces the +# twisted.internet.reactor module with a Reactor instance at runtime. +reactor = cast(ISynapseReactor, reactor_) logger = logging.getLogger("update_database") class MockHomeserver(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore [assignment] def __init__(self, config, **kwargs): super(MockHomeserver, self).__init__( @@ -85,12 +90,10 @@ def main(): args = parser.parse_args() - logging_config = { - "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", - } - - logging.basicConfig(**logging_config) + logging.basicConfig( + level=logging.DEBUG if args.v else logging.INFO, + format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + ) # Load, process and sanity-check the config. hs_config = yaml.safe_load(args.database_config) -- cgit 1.4.1