summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-05-14 11:46:35 +0100
committerGitHub <noreply@github.com>2021-05-14 11:46:35 +0100
commit6482075c95957ad980d9c1323f9f982e6f7aaff4 (patch)
tree593a4e2e96f27d63be8265d1a20228a2271b1be4 /scripts
parentMinor `@cachedList` enhancements (#9975) (diff)
downloadsynapse-6482075c95957ad980d9c1323f9f982e6f7aaff4.tar.xz
Run `black` on the scripts (#9981)
Turns out these scripts weren't getting linted.
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/export_signing_key13
-rwxr-xr-xscripts/generate_config18
-rwxr-xr-xscripts/hash_password6
-rwxr-xr-xscripts/synapse_port_db46
4 files changed, 47 insertions, 36 deletions
diff --git a/scripts/export_signing_key b/scripts/export_signing_key
index 0ed167ea85..bf0139bd64 100755
--- a/scripts/export_signing_key
+++ b/scripts/export_signing_key
@@ -30,7 +30,11 @@ def exit(status: int = 0, message: Optional[str] = None):
 def format_plain(public_key: nacl.signing.VerifyKey):
     print(
         "%s:%s %s"
-        % (public_key.alg, public_key.version, encode_verify_key_base64(public_key),)
+        % (
+            public_key.alg,
+            public_key.version,
+            encode_verify_key_base64(public_key),
+        )
     )
 
 
@@ -50,7 +54,10 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read",
+        "key_file",
+        nargs="+",
+        type=argparse.FileType("r"),
+        help="The key file to read",
     )
 
     parser.add_argument(
@@ -63,7 +70,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--expiry-ts",
         type=int,
-        default=int(time.time() * 1000) + 6*3600000,
+        default=int(time.time() * 1000) + 6 * 3600000,
         help=(
             "The expiry time to use for -x, in milliseconds since 1970. The default "
             "is (now+6h)."
diff --git a/scripts/generate_config b/scripts/generate_config
index 771cbf8d95..931b40c045 100755
--- a/scripts/generate_config
+++ b/scripts/generate_config
@@ -11,23 +11,22 @@ if __name__ == "__main__":
     parser.add_argument(
         "--config-dir",
         default="CONFDIR",
-
         help="The path where the config files are kept. Used to create filenames for "
-             "things like the log config and the signing key. Default: %(default)s",
+        "things like the log config and the signing key. Default: %(default)s",
     )
 
     parser.add_argument(
         "--data-dir",
         default="DATADIR",
         help="The path where the data files are kept. Used to create filenames for "
-             "things like the database and media store. Default: %(default)s",
+        "things like the database and media store. Default: %(default)s",
     )
 
     parser.add_argument(
         "--server-name",
         default="SERVERNAME",
         help="The server name. Used to initialise the server_name config param, but also "
-             "used in the names of some of the config files. Default: %(default)s",
+        "used in the names of some of the config files. Default: %(default)s",
     )
 
     parser.add_argument(
@@ -41,21 +40,22 @@ if __name__ == "__main__":
         "--generate-secrets",
         action="store_true",
         help="Enable generation of new secrets for things like the macaroon_secret_key."
-             "By default, these parameters will be left unset."
+        "By default, these parameters will be left unset.",
     )
 
     parser.add_argument(
-        "-o", "--output-file",
-        type=argparse.FileType('w'),
+        "-o",
+        "--output-file",
+        type=argparse.FileType("w"),
         default=sys.stdout,
         help="File to write the configuration to. Default: stdout",
     )
 
     parser.add_argument(
         "--header-file",
-        type=argparse.FileType('r'),
+        type=argparse.FileType("r"),
         help="File from which to read a header, which will be printed before the "
-             "generated config.",
+        "generated config.",
     )
 
     args = parser.parse_args()
diff --git a/scripts/hash_password b/scripts/hash_password
index a30767f758..1d6fb0d700 100755
--- a/scripts/hash_password
+++ b/scripts/hash_password
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-c",
         "--config",
-        type=argparse.FileType('r'),
+        type=argparse.FileType("r"),
         help=(
             "Path to server config file. "
             "Used to read in bcrypt_rounds and password_pepper."
@@ -72,8 +72,8 @@ if __name__ == "__main__":
     pw = unicodedata.normalize("NFKC", password)
 
     hashed = bcrypt.hashpw(
-        pw.encode('utf8') + password_pepper.encode("utf8"),
+        pw.encode("utf8") + password_pepper.encode("utf8"),
         bcrypt.gensalt(bcrypt_rounds),
-    ).decode('ascii')
+    ).decode("ascii")
 
     print(hashed)
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 5fb5bb35f7..7c7645c05a 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -294,8 +294,7 @@ class Porter(object):
         return table, already_ported, total_to_port, forward_chunk, backward_chunk
 
     async def get_table_constraints(self) -> Dict[str, Set[str]]:
-        """Returns a map of tables that have foreign key constraints to tables they depend on.
-        """
+        """Returns a map of tables that have foreign key constraints to tables they depend on."""
 
         def _get_constraints(txn):
             # We can pull the information about foreign key constraints out from
@@ -504,7 +503,9 @@ class Porter(object):
                 return
 
     def build_db_store(
-        self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
+        self,
+        db_config: DatabaseConnectionConfig,
+        allow_outdated_version: bool = False,
     ):
         """Builds and returns a database store using the provided configuration.
 
@@ -740,7 +741,7 @@ class Porter(object):
             return col
 
         outrows = []
-        for i, row in enumerate(rows):
+        for row in rows:
             try:
                 outrows.append(
                     tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
@@ -890,8 +891,7 @@ class Porter(object):
         await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
 
     async def _setup_events_stream_seqs(self) -> None:
-        """Set the event stream sequences to the correct values.
-        """
+        """Set the event stream sequences to the correct values."""
 
         # We get called before we've ported the events table, so we need to
         # fetch the current positions from the SQLite store.
@@ -920,12 +920,14 @@ class Porter(object):
                 )
 
         await self.postgres_store.db_pool.runInteraction(
-            "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
+            "_setup_events_stream_seqs",
+            _setup_events_stream_seqs_set_pos,
         )
 
-    async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None:
-        """Set a sequence to the correct value.
-        """
+    async def _setup_sequence(
+        self, sequence_name: str, stream_id_tables: Iterable[str]
+    ) -> None:
+        """Set a sequence to the correct value."""
         current_stream_ids = []
         for stream_id_table in stream_id_tables:
             max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
@@ -939,14 +941,19 @@ class Porter(object):
         next_id = max(current_stream_ids) + 1
 
         def r(txn):
-            sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, )
-            txn.execute(sql + " %s", (next_id, ))
+            sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
+            txn.execute(sql + " %s", (next_id,))
 
-        await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r)
+        await self.postgres_store.db_pool.runInteraction(
+            "_setup_%s" % (sequence_name,), r
+        )
 
     async def _setup_auth_chain_sequence(self) -> None:
         curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
-            table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True
+            table="event_auth_chains",
+            keyvalues={},
+            retcol="MAX(chain_id)",
+            allow_none=True,
         )
 
         def r(txn):
@@ -968,8 +975,7 @@ class Porter(object):
 
 
 class Progress(object):
-    """Used to report progress of the port
-    """
+    """Used to report progress of the port"""
 
     def __init__(self):
         self.tables = {}
@@ -994,8 +1000,7 @@ class Progress(object):
 
 
 class CursesProgress(Progress):
-    """Reports progress to a curses window
-    """
+    """Reports progress to a curses window"""
 
     def __init__(self, stdscr):
         self.stdscr = stdscr
@@ -1020,7 +1025,7 @@ class CursesProgress(Progress):
 
         self.total_processed = 0
         self.total_remaining = 0
-        for table, data in self.tables.items():
+        for data in self.tables.values():
             self.total_processed += data["num_done"] - data["start"]
             self.total_remaining += data["total"] - data["num_done"]
 
@@ -1111,8 +1116,7 @@ class CursesProgress(Progress):
 
 
 class TerminalProgress(Progress):
-    """Just prints progress to the terminal
-    """
+    """Just prints progress to the terminal"""
 
     def update(self, table, num_done):
         super(TerminalProgress, self).update(table, num_done)