summary refs log tree commit diff
path: root/synapse/_scripts/synapse_port_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/_scripts/synapse_port_db.py')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py53
1 files changed, 31 insertions, 22 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py

index 6324df883b..123eaae5c5 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py
@@ -21,12 +21,13 @@ import logging import sys import time import traceback -from typing import Dict, Iterable, Optional, Set +from types import TracebackType +from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast import yaml from matrix_common.versionstring import get_distribution_version_string -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as reactor_ from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig @@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import ( from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +from synapse.types import ISynapseReactor from synapse.util import Clock +# Cast safety: Twisted does some naughty magic which replaces the +# twisted.internet.reactor module with a Reactor instance at runtime. +reactor = cast(ISynapseReactor, reactor_) logger = logging.getLogger("synapse_port_db") @@ -159,12 +164,14 @@ IGNORED_TABLES = { # Error returned by the run function. Used at the top-level part of the script to # handle errors and return codes. -end_error = None # type: Optional[str] +end_error: Optional[str] = None # The exec_info for the error, if any. If error is defined but not exec_info the script # will show only the error message without the stacktrace, if exec_info is defined but # not the error then the script will show nothing outside of what's printed in the run # function. If both are defined, the script will print both the error and the stacktrace. -end_error_exec_info = None +end_error_exec_info: Optional[ + Tuple[Type[BaseException], BaseException, TracebackType] +] = None class Store( @@ -236,9 +243,12 @@ class MockHomeserver: return "master" -class Porter(object): - def __init__(self, **kwargs): - self.__dict__.update(kwargs) +class Porter: + def __init__(self, sqlite_config, progress, batch_size, hs_config): + self.sqlite_config = sqlite_config + self.progress = progress + self.batch_size = batch_size + self.hs_config = hs_config async def setup_table(self, table): if table in APPEND_ONLY_TABLES: @@ -323,7 +333,7 @@ class Porter(object): """ txn.execute(sql) - results = {} + results: Dict[str, Set[str]] = {} for table, foreign_table in txn: results.setdefault(table, set()).add(foreign_table) return results @@ -540,7 +550,8 @@ class Porter(object): db_conn, allow_outdated_version=allow_outdated_version ) prepare_database(db_conn, engine, config=self.hs_config) - store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) + # Type safety: ignore that we're using Mock homeservers here. + store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type] db_conn.commit() return store @@ -724,7 +735,9 @@ class Porter(object): except Exception as e: global end_error_exec_info end_error = str(e) - end_error_exec_info = sys.exc_info() + # Type safety: we're in an exception handler, so the exc_info() tuple + # will not be (None, None, None). + end_error_exec_info = sys.exc_info() # type: ignore[assignment] logger.exception("") finally: reactor.stop() @@ -1023,7 +1036,7 @@ class CursesProgress(Progress): curses.init_pair(1, curses.COLOR_RED, -1) curses.init_pair(2, curses.COLOR_GREEN, -1) - self.last_update = 0 + self.last_update = 0.0 self.finished = False @@ -1082,8 +1095,7 @@ class CursesProgress(Progress): left_margin = 5 middle_space = 1 - items = self.tables.items() - items = sorted(items, key=lambda i: (i[1]["perc"], i[0])) + items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0])) for i, (table, data) in enumerate(items): if i + 2 >= rows: @@ -1179,15 +1191,11 @@ def main(): args = parser.parse_args() - logging_config = { - "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", - } - - if args.curses: - logging_config["filename"] = "port-synapse.log" - - logging.basicConfig(**logging_config) + logging.basicConfig( + level=logging.DEBUG if args.v else logging.INFO, + format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", + filename="port-synapse.log" if args.curses else None, + ) sqlite_config = { "name": "sqlite3", @@ -1218,6 +1226,7 @@ def main(): config.parse_config_dict(hs_config, "", "") def start(stdscr=None): + progress: Progress if stdscr: progress = CursesProgress(stdscr) else: