diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 6324df883b..123eaae5c5 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -21,12 +21,13 @@ import logging
import sys
import time
import traceback
-from typing import Dict, Iterable, Optional, Set
+from types import TracebackType
+from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
import yaml
from matrix_common.versionstring import get_distribution_version_string
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as reactor_
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
@@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import (
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+from synapse.types import ISynapseReactor
from synapse.util import Clock
+# Cast safety: Twisted does some naughty magic which replaces the
+# twisted.internet.reactor module with a Reactor instance at runtime.
+reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db")
@@ -159,12 +164,14 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
-end_error = None # type: Optional[str]
+end_error: Optional[str] = None
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
-end_error_exec_info = None
+end_error_exec_info: Optional[
+ Tuple[Type[BaseException], BaseException, TracebackType]
+] = None
class Store(
@@ -236,9 +243,12 @@ class MockHomeserver:
return "master"
-class Porter(object):
- def __init__(self, **kwargs):
- self.__dict__.update(kwargs)
+class Porter:
+ def __init__(self, sqlite_config, progress, batch_size, hs_config):
+ self.sqlite_config = sqlite_config
+ self.progress = progress
+ self.batch_size = batch_size
+ self.hs_config = hs_config
async def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
@@ -323,7 +333,7 @@ class Porter(object):
"""
txn.execute(sql)
- results = {}
+ results: Dict[str, Set[str]] = {}
for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table)
return results
@@ -540,7 +550,8 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
- store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
+ # Type safety: ignore that we're using Mock homeservers here.
+ store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
db_conn.commit()
return store
@@ -724,7 +735,9 @@ class Porter(object):
except Exception as e:
global end_error_exec_info
end_error = str(e)
- end_error_exec_info = sys.exc_info()
+ # Type safety: we're in an exception handler, so the exc_info() tuple
+ # will not be (None, None, None).
+ end_error_exec_info = sys.exc_info() # type: ignore[assignment]
logger.exception("")
finally:
reactor.stop()
@@ -1023,7 +1036,7 @@ class CursesProgress(Progress):
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
- self.last_update = 0
+ self.last_update = 0.0
self.finished = False
@@ -1082,8 +1095,7 @@ class CursesProgress(Progress):
left_margin = 5
middle_space = 1
- items = self.tables.items()
- items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
+ items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
@@ -1179,15 +1191,11 @@ def main():
args = parser.parse_args()
- logging_config = {
- "level": logging.DEBUG if args.v else logging.INFO,
- "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
- }
-
- if args.curses:
- logging_config["filename"] = "port-synapse.log"
-
- logging.basicConfig(**logging_config)
+ logging.basicConfig(
+ level=logging.DEBUG if args.v else logging.INFO,
+ format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+ filename="port-synapse.log" if args.curses else None,
+ )
sqlite_config = {
"name": "sqlite3",
@@ -1218,6 +1226,7 @@ def main():
config.parse_config_dict(hs_config, "", "")
def start(stdscr=None):
+ progress: Progress
if stdscr:
progress = CursesProgress(stdscr)
else:
|