summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/export_signing_key.py6
-rwxr-xr-xsynapse/_scripts/move_remote_media_to_new_store.py9
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py53
-rwxr-xr-xsynapse/_scripts/update_synapse_database.py19
4 files changed, 49 insertions, 38 deletions
diff --git a/synapse/_scripts/export_signing_key.py b/synapse/_scripts/export_signing_key.py
index 3d254348f1..66481533e9 100755
--- a/synapse/_scripts/export_signing_key.py
+++ b/synapse/_scripts/export_signing_key.py
@@ -17,8 +17,8 @@ import sys
 import time
 from typing import Optional
 
-import nacl.signing
 from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
+from signedjson.types import VerifyKey
 
 
 def exit(status: int = 0, message: Optional[str] = None):
@@ -27,7 +27,7 @@ def exit(status: int = 0, message: Optional[str] = None):
     sys.exit(status)
 
 
-def format_plain(public_key: nacl.signing.VerifyKey):
+def format_plain(public_key: VerifyKey):
     print(
         "%s:%s %s"
         % (
@@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
     )
 
 
-def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
+def format_for_config(public_key: VerifyKey, expiry_ts: int):
     print(
         '  "%s:%s": { key: "%s", expired_ts: %i }'
         % (
diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py
index 9667d95dfe..f53bf790af 100755
--- a/synapse/_scripts/move_remote_media_to_new_store.py
+++ b/synapse/_scripts/move_remote_media_to_new_store.py
@@ -109,10 +109,9 @@ if __name__ == "__main__":
     parser.add_argument("dest_repo", help="Path to source content repo")
     args = parser.parse_args()
 
-    logging_config = {
-        "level": logging.DEBUG if args.v else logging.INFO,
-        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
-    }
-    logging.basicConfig(**logging_config)
+    logging.basicConfig(
+        level=logging.DEBUG if args.v else logging.INFO,
+        format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+    )
 
     main(args.src_repo, args.dest_repo)
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 6324df883b..123eaae5c5 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -21,12 +21,13 @@ import logging
 import sys
 import time
 import traceback
-from typing import Dict, Iterable, Optional, Set
+from types import TracebackType
+from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
 
 import yaml
 from matrix_common.versionstring import get_distribution_version_string
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as reactor_
 
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
@@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import (
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
+from synapse.types import ISynapseReactor
 from synapse.util import Clock
 
+# Cast safety: Twisted does some naughty magic which replaces the
+# twisted.internet.reactor module with a Reactor instance at runtime.
+reactor = cast(ISynapseReactor, reactor_)
 logger = logging.getLogger("synapse_port_db")
 
 
@@ -159,12 +164,14 @@ IGNORED_TABLES = {
 
 # Error returned by the run function. Used at the top-level part of the script to
 # handle errors and return codes.
-end_error = None  # type: Optional[str]
+end_error: Optional[str] = None
 # The exec_info for the error, if any. If error is defined but not exec_info the script
 # will show only the error message without the stacktrace, if exec_info is defined but
 # not the error then the script will show nothing outside of what's printed in the run
 # function. If both are defined, the script will print both the error and the stacktrace.
-end_error_exec_info = None
+end_error_exec_info: Optional[
+    Tuple[Type[BaseException], BaseException, TracebackType]
+] = None
 
 
 class Store(
@@ -236,9 +243,12 @@ class MockHomeserver:
         return "master"
 
 
-class Porter(object):
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
+class Porter:
+    def __init__(self, sqlite_config, progress, batch_size, hs_config):
+        self.sqlite_config = sqlite_config
+        self.progress = progress
+        self.batch_size = batch_size
+        self.hs_config = hs_config
 
     async def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
@@ -323,7 +333,7 @@ class Porter(object):
             """
             txn.execute(sql)
 
-            results = {}
+            results: Dict[str, Set[str]] = {}
             for table, foreign_table in txn:
                 results.setdefault(table, set()).add(foreign_table)
             return results
@@ -540,7 +550,8 @@ class Porter(object):
                 db_conn, allow_outdated_version=allow_outdated_version
             )
             prepare_database(db_conn, engine, config=self.hs_config)
-            store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
+            # Type safety: ignore that we're using Mock homeservers here.
+            store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)  # type: ignore[arg-type]
             db_conn.commit()
 
         return store
@@ -724,7 +735,9 @@ class Porter(object):
         except Exception as e:
             global end_error_exec_info
             end_error = str(e)
-            end_error_exec_info = sys.exc_info()
+            # Type safety: we're in an exception handler, so the exc_info() tuple
+            # will not be (None, None, None).
+            end_error_exec_info = sys.exc_info()  # type: ignore[assignment]
             logger.exception("")
         finally:
             reactor.stop()
@@ -1023,7 +1036,7 @@ class CursesProgress(Progress):
         curses.init_pair(1, curses.COLOR_RED, -1)
         curses.init_pair(2, curses.COLOR_GREEN, -1)
 
-        self.last_update = 0
+        self.last_update = 0.0
 
         self.finished = False
 
@@ -1082,8 +1095,7 @@ class CursesProgress(Progress):
         left_margin = 5
         middle_space = 1
 
-        items = self.tables.items()
-        items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
+        items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
 
         for i, (table, data) in enumerate(items):
             if i + 2 >= rows:
@@ -1179,15 +1191,11 @@ def main():
 
     args = parser.parse_args()
 
-    logging_config = {
-        "level": logging.DEBUG if args.v else logging.INFO,
-        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
-    }
-
-    if args.curses:
-        logging_config["filename"] = "port-synapse.log"
-
-    logging.basicConfig(**logging_config)
+    logging.basicConfig(
+        level=logging.DEBUG if args.v else logging.INFO,
+        format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+        filename="port-synapse.log" if args.curses else None,
+    )
 
     sqlite_config = {
         "name": "sqlite3",
@@ -1218,6 +1226,7 @@ def main():
     config.parse_config_dict(hs_config, "", "")
 
     def start(stdscr=None):
+        progress: Progress
         if stdscr:
             progress = CursesProgress(stdscr)
         else:
diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py
index f43676afaa..736f58836d 100755
--- a/synapse/_scripts/update_synapse_database.py
+++ b/synapse/_scripts/update_synapse_database.py
@@ -16,22 +16,27 @@
 import argparse
 import logging
 import sys
+from typing import cast
 
 import yaml
 from matrix_common.versionstring import get_distribution_version_string
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as reactor_
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.types import ISynapseReactor
 
+# Cast safety: Twisted does some naughty magic which replaces the
+# twisted.internet.reactor module with a Reactor instance at runtime.
+reactor = cast(ISynapseReactor, reactor_)
 logger = logging.getLogger("update_database")
 
 
 class MockHomeserver(HomeServer):
-    DATASTORE_CLASS = DataStore
+    DATASTORE_CLASS = DataStore  # type: ignore [assignment]
 
     def __init__(self, config, **kwargs):
         super(MockHomeserver, self).__init__(
@@ -85,12 +90,10 @@ def main():
 
     args = parser.parse_args()
 
-    logging_config = {
-        "level": logging.DEBUG if args.v else logging.INFO,
-        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
-    }
-
-    logging.basicConfig(**logging_config)
+    logging.basicConfig(
+        level=logging.DEBUG if args.v else logging.INFO,
+        format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
+    )
 
     # Load, process and sanity-check the config.
     hs_config = yaml.safe_load(args.database_config)