diff options
Diffstat (limited to '')
-rw-r--r-- | scripts-dev/check_auth.py | 36 | ||||
-rw-r--r-- | scripts-dev/check_event_hash.py | 32 | ||||
-rw-r--r-- | scripts-dev/check_signature.py | 36 | ||||
-rw-r--r-- | scripts-dev/convert_server_keys.py | 40 | ||||
-rwxr-xr-x | scripts-dev/definitions.py | 54 | ||||
-rwxr-xr-x | scripts-dev/dump_macaroon.py | 13 | ||||
-rwxr-xr-x | scripts-dev/federation_client.py | 99 | ||||
-rw-r--r-- | scripts-dev/hash_history.py | 62 | ||||
-rwxr-xr-x | scripts-dev/list_url_patterns.py | 16 | ||||
-rw-r--r-- | scripts-dev/tail-synapse.py | 26 | ||||
-rwxr-xr-x | scripts/hash_password | 5 | ||||
-rwxr-xr-x | scripts/move_remote_media_to_new_store.py | 36 | ||||
-rwxr-xr-x | scripts/register_new_matrix_user | 76 | ||||
-rwxr-xr-x | scripts/synapse_port_db | 275 |
14 files changed, 404 insertions, 402 deletions
diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py index 4fa8792a5f..b3d11f49ec 100644 --- a/scripts-dev/check_auth.py +++ b/scripts-dev/check_auth.py @@ -1,21 +1,20 @@ -from synapse.events import FrozenEvent -from synapse.api.auth import Auth - -from mock import Mock +from __future__ import print_function import argparse import itertools import json import sys +from mock import Mock + +from synapse.api.auth import Auth +from synapse.events import FrozenEvent + def check_auth(auth, auth_chain, events): auth_chain.sort(key=lambda e: e.depth) - auth_map = { - e.event_id: e - for e in auth_chain - } + auth_map = {e.event_id: e for e in auth_chain} create_events = {} for e in auth_chain: @@ -25,31 +24,26 @@ def check_auth(auth, auth_chain, events): for e in itertools.chain(auth_chain, events): auth_events_list = [auth_map[i] for i, _ in e.auth_events] - auth_events = { - (e.type, e.state_key): e - for e in auth_events_list - } + auth_events = {(e.type, e.state_key): e for e in auth_events_list} auth_events[("m.room.create", "")] = create_events[e.room_id] try: auth.check(e, auth_events=auth_events) except Exception as ex: - print "Failed:", e.event_id, e.type, e.state_key - print "Auth_events:", auth_events - print ex - print json.dumps(e.get_dict(), sort_keys=True, indent=4) + print("Failed:", e.event_id, e.type, e.state_key) + print("Auth_events:", auth_events) + print(ex) + print(json.dumps(e.get_dict(), sort_keys=True, indent=4)) # raise - print "Success:", e.event_id, e.type, e.state_key + print("Success:", e.event_id, e.type, e.state_key) + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - 'json', - nargs='?', - type=argparse.FileType('r'), - default=sys.stdin, + 'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin ) args = parser.parse_args() diff --git a/scripts-dev/check_event_hash.py b/scripts-dev/check_event_hash.py index 7ccae34d48..8535f99697 100644 --- a/scripts-dev/check_event_hash.py +++ b/scripts-dev/check_event_hash.py @@ -1,10 +1,15 @@ -from synapse.crypto.event_signing import * -from unpaddedbase64 import encode_base64 - import argparse import hashlib -import sys import json +import logging +import sys + +from unpaddedbase64 import encode_base64 + +from synapse.crypto.event_signing import ( + check_event_content_hash, + compute_event_reference_hash, +) class dictobj(dict): @@ -24,27 +29,26 @@ class dictobj(dict): def main(): parser = argparse.ArgumentParser() - parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), - default=sys.stdin) + parser.add_argument( + "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + ) args = parser.parse_args() logging.basicConfig() event_json = dictobj(json.load(args.input_json)) - algorithms = { - "sha256": hashlib.sha256, - } + algorithms = {"sha256": hashlib.sha256} for alg_name in event_json.hashes: if check_event_content_hash(event_json, algorithms[alg_name]): - print "PASS content hash %s" % (alg_name,) + print("PASS content hash %s" % (alg_name,)) else: - print "FAIL content hash %s" % (alg_name,) + print("FAIL content hash %s" % (alg_name,)) for algorithm in algorithms.values(): name, h_bytes = compute_event_reference_hash(event_json, algorithm) - print "Reference hash %s: %s" % (name, encode_base64(h_bytes)) + print("Reference hash %s: %s" % (name, encode_base64(h_bytes))) -if __name__=="__main__": - main() +if __name__ == "__main__": + main() diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index 079577908a..612f17ca7f 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -1,15 +1,15 @@ -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes, write_signing_keys -from unpaddedbase64 import decode_base64 - -import urllib2 +import argparse import json +import logging import sys +import urllib2 + import dns.resolver -import pprint -import argparse -import logging +from signedjson.key import decode_verify_key_bytes, write_signing_keys +from signedjson.sign import verify_signed_json +from unpaddedbase64 import decode_base64 + def get_targets(server_name): if ":" in server_name: @@ -23,6 +23,7 @@ def get_targets(server_name): except dns.resolver.NXDOMAIN: yield (server_name, 8448) + def get_server_keys(server_name, target, port): url = "https://%s:%i/_matrix/key/v1" % (target, port) keys = json.load(urllib2.urlopen(url)) @@ -33,12 +34,14 @@ def get_server_keys(server_name, target, port): verify_keys[key_id] = verify_key return verify_keys + def main(): parser = argparse.ArgumentParser() parser.add_argument("signature_name") - parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), - default=sys.stdin) + parser.add_argument( + "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + ) args = parser.parse_args() logging.basicConfig() @@ -48,24 +51,23 @@ def main(): for target, port in get_targets(server_name): try: keys = get_server_keys(server_name, target, port) - print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port) + print("Using keys from https://%s:%s/_matrix/key/v1" % (target, port)) write_signing_keys(sys.stdout, keys.values()) break - except: + except Exception: logging.exception("Error talking to %s:%s", target, port) json_to_check = json.load(args.input_json) - print "Checking JSON:" + print("Checking JSON:") for key_id in json_to_check["signatures"][args.signature_name]: try: key = keys[key_id] verify_signed_json(json_to_check, args.signature_name, key) - print "PASS %s" % (key_id,) - except: + print("PASS %s" % (key_id,)) + except Exception: logging.exception("Check for key %s failed" % (key_id,)) - print "FAIL %s" % (key_id,) + print("FAIL %s" % (key_id,)) if __name__ == '__main__': main() - diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 151551f22c..dde8596697 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -1,13 +1,21 @@ -import psycopg2 -import yaml -import sys +import hashlib import json +import sys import time -import hashlib -from unpaddedbase64 import encode_base64 + +import six + +import psycopg2 +import yaml +from canonicaljson import encode_canonical_json from signedjson.key import read_signing_keys from signedjson.sign import sign_json -from canonicaljson import encode_canonical_json +from unpaddedbase64 import encode_base64 + +if six.PY2: + db_type = six.moves.builtins.buffer +else: + db_type = memoryview def select_v1_keys(connection): @@ -39,7 +47,9 @@ def select_v2_json(connection): cursor.close() results = {} for server_name, key_id, key_json in rows: - results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8")) + results.setdefault(server_name, {})[key_id] = json.loads( + str(key_json).decode("utf-8") + ) return results @@ -47,10 +57,7 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate): return { "old_verify_keys": {}, "server_name": server_name, - "verify_keys": { - key_id: {"key": key} - for key_id, key in keys.items() - }, + "verify_keys": {key_id: {"key": key} for key_id, key in keys.items()}, "valid_until_ts": valid_until, "tls_fingerprints": [fingerprint(certificate)], } @@ -65,7 +72,7 @@ def rows_v2(server, json): valid_until = json["valid_until_ts"] key_json = encode_canonical_json(json) for key_id in json["verify_keys"]: - yield (server, key_id, "-", valid_until, valid_until, buffer(key_json)) + yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) def main(): @@ -87,7 +94,7 @@ def main(): result = {} for server in keys: - if not server in json: + if server not in json: v2_json = convert_v1_to_v2( server, valid_until, keys[server], certificates[server] ) @@ -96,10 +103,7 @@ def main(): yaml.safe_dump(result, sys.stdout, default_flow_style=False) - rows = list( - row for server, json in result.items() - for row in rows_v2(server, json) - ) + rows = list(row for server, json in result.items() for row in rows_v2(server, json)) cursor = connection.cursor() cursor.executemany( @@ -107,7 +111,7 @@ def main(): " server_name, key_id, from_server," " ts_added_ms, ts_valid_until_ms, key_json" ") VALUES (%s, %s, %s, %s, %s, %s)", - rows + rows, ) connection.commit() diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 47dac7772d..1deb0fe2b7 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -1,8 +1,16 @@ #! /usr/bin/python +from __future__ import print_function + +import argparse import ast +import os +import re +import sys + import yaml + class DefinitionVisitor(ast.NodeVisitor): def __init__(self): super(DefinitionVisitor, self).__init__() @@ -42,15 +50,18 @@ def non_empty(defs): functions = {name: non_empty(f) for name, f in defs['def'].items()} classes = {name: non_empty(f) for name, f in defs['class'].items()} result = {} - if functions: result['def'] = functions - if classes: result['class'] = classes + if functions: + result['def'] = functions + if classes: + result['class'] = classes names = defs['names'] uses = [] for name in names.get('Load', ()): if name not in names.get('Param', ()) and name not in names.get('Store', ()): uses.append(name) uses.extend(defs['attrs']) - if uses: result['uses'] = uses + if uses: + result['uses'] = uses result['names'] = names result['attrs'] = defs['attrs'] return result @@ -95,7 +106,6 @@ def used_names(prefix, item, defs, names): if __name__ == '__main__': - import sys, os, argparse, re parser = argparse.ArgumentParser(description='Find definitions.') parser.add_argument( @@ -105,24 +115,28 @@ if __name__ == '__main__': "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" ) parser.add_argument( - "--pattern", action="append", metavar="REGEXP", - help="Search for a pattern" + "--pattern", action="append", metavar="REGEXP", help="Search for a pattern" ) parser.add_argument( - "directories", nargs='+', metavar="DIR", - help="Directories to search for definitions" + "directories", + nargs='+', + metavar="DIR", + help="Directories to search for definitions", ) parser.add_argument( - "--referrers", default=0, type=int, - help="Include referrers up to the given depth" + "--referrers", + default=0, + type=int, + help="Include referrers up to the given depth", ) parser.add_argument( - "--referred", default=0, type=int, - help="Include referred down to the given depth" + "--referred", + default=0, + type=int, + help="Include referred down to the given depth", ) parser.add_argument( - "--format", default="yaml", - help="Output format, one of 'yaml' or 'dot'" + "--format", default="yaml", help="Output format, one of 'yaml' or 'dot'" ) args = parser.parse_args() @@ -162,7 +176,7 @@ if __name__ == '__main__': for used_by in entry.get("used", ()): referrers.add(used_by) for name, definition in names.items(): - if not name in referrers: + if name not in referrers: continue if ignore and any(pattern.match(name) for pattern in ignore): continue @@ -176,7 +190,7 @@ if __name__ == '__main__': for uses in entry.get("uses", ()): referred.add(uses) for name, definition in names.items(): - if not name in referred: + if name not in referred: continue if ignore and any(pattern.match(name) for pattern in ignore): continue @@ -185,12 +199,12 @@ if __name__ == '__main__': if args.format == 'yaml': yaml.dump(result, sys.stdout, default_flow_style=False) elif args.format == 'dot': - print "digraph {" + print("digraph {") for name, entry in result.items(): - print name + print(name) for used_by in entry.get("used", ()): if used_by in result: - print used_by, "->", name - print "}" + print(used_by, "->", name) + print("}") else: raise ValueError("Unknown format %r" % (args.format)) diff --git a/scripts-dev/dump_macaroon.py b/scripts-dev/dump_macaroon.py index fcc5568835..22b30fa78e 100755 --- a/scripts-dev/dump_macaroon.py +++ b/scripts-dev/dump_macaroon.py @@ -1,8 +1,11 @@ #!/usr/bin/env python2 -import pymacaroons +from __future__ import print_function + import sys +import pymacaroons + if len(sys.argv) == 1: sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],)) sys.exit(1) @@ -11,14 +14,14 @@ macaroon_string = sys.argv[1] key = sys.argv[2] if len(sys.argv) > 2 else None macaroon = pymacaroons.Macaroon.deserialize(macaroon_string) -print macaroon.inspect() +print(macaroon.inspect()) -print "" +print("") verifier = pymacaroons.Verifier() verifier.satisfy_general(lambda c: True) try: verifier.verify(macaroon, key) - print "Signature is correct" + print("Signature is correct") except Exception as e: - print str(e) + print(str(e)) diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index d2acc7654d..2566ce7cef 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -18,21 +18,21 @@ from __future__ import print_function import argparse +import base64 +import json +import sys from urlparse import urlparse, urlunparse import nacl.signing -import json -import base64 import requests -import sys - -from requests.adapters import HTTPAdapter import srvlookup import yaml +from requests.adapters import HTTPAdapter # uncomment the following to enable debug logging of http requests -#from httplib import HTTPConnection -#HTTPConnection.debuglevel = 1 +# from httplib import HTTPConnection +# HTTPConnection.debuglevel = 1 + def encode_base64(input_bytes): """Encode bytes as a base64 string without any padding.""" @@ -58,15 +58,15 @@ def decode_base64(input_string): def encode_canonical_json(value): return json.dumps( - value, - # Encode code-points outside of ASCII as UTF-8 rather than \u escapes - ensure_ascii=False, - # Remove unecessary white space. - separators=(',',':'), - # Sort the keys of dictionaries. - sort_keys=True, - # Encode the resulting unicode as UTF-8 bytes. - ).encode("UTF-8") + value, + # Encode code-points outside of ASCII as UTF-8 rather than \u escapes + ensure_ascii=False, + # Remove unecessary white space. + separators=(',', ':'), + # Sort the keys of dictionaries. + sort_keys=True, + # Encode the resulting unicode as UTF-8 bytes. + ).encode("UTF-8") def sign_json(json_object, signing_key, signing_name): @@ -88,6 +88,7 @@ def sign_json(json_object, signing_key, signing_name): NACL_ED25519 = "ed25519" + def decode_signing_key_base64(algorithm, version, key_base64): """Decode a base64 encoded signing key Args: @@ -143,14 +144,12 @@ def request_json(method, origin_name, origin_key, destination, path, content): authorization_headers = [] for key, sig in signed_json["signatures"][origin_name].items(): - header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( - origin_name, key, sig, - ) + header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig) authorization_headers.append(bytes(header)) - print ("Authorization: %s" % header, file=sys.stderr) + print("Authorization: %s" % header, file=sys.stderr) dest = "matrix://%s%s" % (destination, path) - print ("Requesting %s" % dest, file=sys.stderr) + print("Requesting %s" % dest, file=sys.stderr) s = requests.Session() s.mount("matrix://", MatrixConnectionAdapter()) @@ -158,10 +157,7 @@ def request_json(method, origin_name, origin_key, destination, path, content): result = s.request( method=method, url=dest, - headers={ - "Host": destination, - "Authorization": authorization_headers[0] - }, + headers={"Host": destination, "Authorization": authorization_headers[0]}, verify=False, data=content, ) @@ -171,50 +167,50 @@ def request_json(method, origin_name, origin_key, destination, path, content): def main(): parser = argparse.ArgumentParser( - description= - "Signs and sends a federation request to a matrix homeserver", + description="Signs and sends a federation request to a matrix homeserver" ) parser.add_argument( - "-N", "--server-name", + "-N", + "--server-name", help="Name to give as the local homeserver. If unspecified, will be " - "read from the config file.", + "read from the config file.", ) parser.add_argument( - "-k", "--signing-key-path", + "-k", + "--signing-key-path", help="Path to the file containing the private ed25519 key to sign the " - "request with.", + "request with.", ) parser.add_argument( - "-c", "--config", + "-c", + "--config", default="homeserver.yaml", help="Path to server config file. Ignored if --server-name and " - "--signing-key-path are both given.", + "--signing-key-path are both given.", ) parser.add_argument( - "-d", "--destination", + "-d", + "--destination", default="matrix.org", help="name of the remote homeserver. We will do SRV lookups and " - "connect appropriately.", + "connect appropriately.", ) parser.add_argument( - "-X", "--method", + "-X", + "--method", help="HTTP method to use for the request. Defaults to GET if --data is" - "unspecified, POST if it is." + "unspecified, POST if it is.", ) - parser.add_argument( - "--body", - help="Data to send as the body of the HTTP request" - ) + parser.add_argument("--body", help="Data to send as the body of the HTTP request") parser.add_argument( - "path", - help="request path. We will add '/_matrix/federation/v1/' to this." + "path", help="request path. We will add '/_matrix/federation/v1/' to this." ) args = parser.parse_args() @@ -227,13 +223,15 @@ def main(): result = request_json( args.method, - args.server_name, key, args.destination, + args.server_name, + key, + args.destination, "/_matrix/federation/v1/" + args.path, content=args.body, ) json.dump(result, sys.stdout) - print ("") + print("") def read_args_from_config(args): @@ -253,7 +251,7 @@ class MatrixConnectionAdapter(HTTPAdapter): return s, 8448 if ":" in s: - out = s.rsplit(":",1) + out = s.rsplit(":", 1) try: port = int(out[1]) except ValueError: @@ -263,7 +261,7 @@ class MatrixConnectionAdapter(HTTPAdapter): try: srv = srvlookup.lookup("matrix", "tcp", s)[0] return srv.host, srv.port - except: + except Exception: return s, 8448 def get_connection(self, url, proxies=None): @@ -272,10 +270,9 @@ class MatrixConnectionAdapter(HTTPAdapter): (host, port) = self.lookup(parsed.netloc) netloc = "%s:%d" % (host, port) print("Connecting to %s" % (netloc,), file=sys.stderr) - url = urlunparse(( - "https", netloc, parsed.path, parsed.params, parsed.query, - parsed.fragment, - )) + url = urlunparse( + ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) + ) return super(MatrixConnectionAdapter, self).get_connection(url, proxies) diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index 616d6a10e7..514d80fa60 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -1,23 +1,31 @@ -from synapse.storage.pdu import PduStore -from synapse.storage.signatures import SignatureStore -from synapse.storage._base import SQLBaseStore -from synapse.federation.units import Pdu -from synapse.crypto.event_signing import ( - add_event_pdu_content_hash, compute_pdu_event_reference_hash -) -from synapse.api.events.utils import prune_pdu -from unpaddedbase64 import encode_base64, decode_base64 -from canonicaljson import encode_canonical_json +from __future__ import print_function + import sqlite3 import sys +from unpaddedbase64 import decode_base64, encode_base64 + +from synapse.crypto.event_signing import ( + add_event_pdu_content_hash, + compute_pdu_event_reference_hash, +) +from synapse.federation.units import Pdu +from synapse.storage._base import SQLBaseStore +from synapse.storage.pdu import PduStore +from synapse.storage.signatures import SignatureStore + + class Store(object): _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] - _get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"] + _get_pdu_origin_signatures_txn = SignatureStore.__dict__[ + "_get_pdu_origin_signatures_txn" + ] _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"] - _store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"] + _store_pdu_reference_hash_txn = SignatureStore.__dict__[ + "_store_pdu_reference_hash_txn" + ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] @@ -26,9 +34,7 @@ store = Store() def select_pdus(cursor): - cursor.execute( - "SELECT pdu_id, origin FROM pdus ORDER BY depth ASC" - ) + cursor.execute("SELECT pdu_id, origin FROM pdus ORDER BY depth ASC") ids = cursor.fetchall() @@ -41,23 +47,30 @@ def select_pdus(cursor): for pdu in pdus: try: if pdu.prev_pdus: - print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus + print("PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) for pdu_id, origin, hashes in pdu.prev_pdus: ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)] hashes[ref_alg] = encode_base64(ref_hsh) - store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh) - print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus + store._store_prev_pdu_hash_txn( + cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh + ) + print("SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus) pdu = add_event_pdu_content_hash(pdu) ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu) reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh) - store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh) + store._store_pdu_reference_hash_txn( + cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh + ) for alg, hsh_base64 in pdu.hashes.items(): - print alg, hsh_base64 - store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)) + print(alg, hsh_base64) + store._store_pdu_content_hash_txn( + cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64) + ) + + except Exception: + print("FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus) - except: - print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus def main(): conn = sqlite3.connect(sys.argv[1]) @@ -65,5 +78,6 @@ def main(): select_pdus(cursor) conn.commit() -if __name__=='__main__': + +if __name__ == '__main__': main() diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py index 58d40c4ff4..da027be26e 100755 --- a/scripts-dev/list_url_patterns.py +++ b/scripts-dev/list_url_patterns.py @@ -1,18 +1,17 @@ #! /usr/bin/python -import ast import argparse +import ast import os import sys + import yaml PATTERNS_V1 = [] PATTERNS_V2 = [] -RESULT = { - "v1": PATTERNS_V1, - "v2": PATTERNS_V2, -} +RESULT = {"v1": PATTERNS_V1, "v2": PATTERNS_V2} + class CallVisitor(ast.NodeVisitor): def visit_Call(self, node): @@ -21,7 +20,6 @@ class CallVisitor(ast.NodeVisitor): else: return - if name == "client_path_patterns": PATTERNS_V1.append(node.args[0].s) elif name == "client_v2_patterns": @@ -42,8 +40,10 @@ def find_patterns_in_file(filepath): parser = argparse.ArgumentParser(description='Find url patterns.') parser.add_argument( - "directories", nargs='+', metavar="DIR", - help="Directories to search for definitions" + "directories", + nargs='+', + metavar="DIR", + help="Directories to search for definitions", ) args = parser.parse_args() diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py index 18be711e92..7c9985d9f0 100644 --- a/scripts-dev/tail-synapse.py +++ b/scripts-dev/tail-synapse.py @@ -1,8 +1,9 @@ -import requests import collections +import json import sys import time -import json + +import requests Entry = collections.namedtuple("Entry", "name position rows") @@ -30,11 +31,11 @@ def parse_response(content): def replicate(server, streams): - return parse_response(requests.get( - server + "/_synapse/replication", - verify=False, - params=streams - ).content) + return parse_response( + requests.get( + server + "/_synapse/replication", verify=False, params=streams + ).content + ) def main(): @@ -45,16 +46,16 @@ def main(): try: streams = { row.name: row.position - for row in replicate(server, {"streams":"-1"})["streams"].rows + for row in replicate(server, {"streams": "-1"})["streams"].rows } - except requests.exceptions.ConnectionError as e: + except requests.exceptions.ConnectionError: time.sleep(0.1) while True: try: results = replicate(server, streams) - except: - sys.stdout.write("connection_lost("+ repr(streams) + ")\n") + except Exception: + sys.stdout.write("connection_lost(" + repr(streams) + ")\n") break for update in results.values(): for row in update.rows: @@ -62,6 +63,5 @@ def main(): streams[update.name] = update.position - -if __name__=='__main__': +if __name__ == '__main__': main() diff --git a/scripts/hash_password b/scripts/hash_password index 215ab25cfe..a62bb5aa83 100755 --- a/scripts/hash_password +++ b/scripts/hash_password @@ -1,12 +1,10 @@ #!/usr/bin/env python import argparse - +import getpass import sys import bcrypt -import getpass - import yaml bcrypt_rounds=12 @@ -52,4 +50,3 @@ if __name__ == "__main__": password = prompt_for_pass() print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds)) - diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index 7914ead889..e630936f78 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -36,12 +36,9 @@ from __future__ import print_function import argparse import logging - -import sys - import os - import shutil +import sys from synapse.rest.media.v1.filepath import MediaFilePaths @@ -77,24 +74,23 @@ def move_media(origin_server, file_id, src_paths, dest_paths): if not os.path.exists(original_file): logger.warn( "Original for %s/%s (%s) does not exist", - origin_server, file_id, original_file, + origin_server, + file_id, + original_file, ) else: mkdir_and_move( - original_file, - dest_paths.remote_media_filepath(origin_server, file_id), + original_file, dest_paths.remote_media_filepath(origin_server, file_id) ) # now look for thumbnails - original_thumb_dir = src_paths.remote_media_thumbnail_dir( - origin_server, file_id, - ) + original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id) if not os.path.exists(original_thumb_dir): return mkdir_and_move( original_thumb_dir, - dest_paths.remote_media_thumbnail_dir(origin_server, file_id) + dest_paths.remote_media_thumbnail_dir(origin_server, file_id), ) @@ -109,24 +105,16 @@ def mkdir_and_move(original_file, dest_file): if __name__ == "__main__": parser = argparse.ArgumentParser( - description=__doc__, - formatter_class = argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument( - "-v", action='store_true', help='enable debug logging') - parser.add_argument( - "src_repo", - help="Path to source content repo", - ) - parser.add_argument( - "dest_repo", - help="Path to source content repo", + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) + parser.add_argument("-v", action='store_true', help='enable debug logging') + parser.add_argument("src_repo", help="Path to source content repo") + parser.add_argument("dest_repo", help="Path to source content repo") args = parser.parse_args() logging_config = { "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", } logging.basicConfig(**logging_config) diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 91bdb3a25b..89143c5d59 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import print_function import argparse import getpass @@ -22,19 +23,23 @@ import hmac import json import sys import urllib2 + +from six import input + import yaml def request_registration(user, password, server_location, shared_secret, admin=False): req = urllib2.Request( "%s/_matrix/client/r0/admin/register" % (server_location,), - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, ) try: if sys.version_info[:3] >= (2, 7, 9): # As of version 2.7.9, urllib2 now checks SSL certs import ssl + f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) else: f = urllib2.urlopen(req) @@ -42,18 +47,15 @@ def request_registration(user, password, server_location, shared_secret, admin=F f.close() nonce = json.loads(body)["nonce"] except urllib2.HTTPError as e: - print "ERROR! Received %d %s" % (e.code, e.reason,) + print("ERROR! Received %d %s" % (e.code, e.reason)) if 400 <= e.code < 500: if e.info().type == "application/json": resp = json.load(e) if "error" in resp: - print resp["error"] + print(resp["error"]) sys.exit(1) - mac = hmac.new( - key=shared_secret, - digestmod=hashlib.sha1, - ) + mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1) mac.update(nonce) mac.update("\x00") @@ -75,30 +77,31 @@ def request_registration(user, password, server_location, shared_secret, admin=F server_location = server_location.rstrip("/") - print "Sending registration request..." + print("Sending registration request...") req = urllib2.Request( "%s/_matrix/client/r0/admin/register" % (server_location,), data=json.dumps(data), - headers={'Content-Type': 'application/json'} + headers={'Content-Type': 'application/json'}, ) try: if sys.version_info[:3] >= (2, 7, 9): # As of version 2.7.9, urllib2 now checks SSL certs import ssl + f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) else: f = urllib2.urlopen(req) f.read() f.close() - print "Success." + print("Success.") except urllib2.HTTPError as e: - print "ERROR! Received %d %s" % (e.code, e.reason,) + print("ERROR! Received %d %s" % (e.code, e.reason)) if 400 <= e.code < 500: if e.info().type == "application/json": resp = json.load(e) if "error" in resp: - print resp["error"] + print(resp["error"]) sys.exit(1) @@ -106,35 +109,35 @@ def register_new_user(user, password, server_location, shared_secret, admin): if not user: try: default_user = getpass.getuser() - except: + except Exception: default_user = None if default_user: - user = raw_input("New user localpart [%s]: " % (default_user,)) + user = input("New user localpart [%s]: " % (default_user,)) if not user: user = default_user else: - user = raw_input("New user localpart: ") + user = input("New user localpart: ") if not user: - print "Invalid user name" + print("Invalid user name") sys.exit(1) if not password: password = getpass.getpass("Password: ") if not password: - print "Password cannot be blank." + print("Password cannot be blank.") sys.exit(1) confirm_password = getpass.getpass("Confirm password: ") if password != confirm_password: - print "Passwords do not match" + print("Passwords do not match") sys.exit(1) if admin is None: - admin = raw_input("Make admin [no]: ") + admin = input("Make admin [no]: ") if admin in ("y", "yes", "true"): admin = True else: @@ -146,42 +149,51 @@ def register_new_user(user, password, server_location, shared_secret, admin): if __name__ == "__main__": parser = argparse.ArgumentParser( description="Used to register new users with a given home server when" - " registration has been disabled. The home server must be" - " configured with the 'registration_shared_secret' option" - " set.", + " registration has been disabled. The home server must be" + " configured with the 'registration_shared_secret' option" + " set." ) parser.add_argument( - "-u", "--user", + "-u", + "--user", default=None, help="Local part of the new user. Will prompt if omitted.", ) parser.add_argument( - "-p", "--password", + "-p", + "--password", default=None, help="New password for user. Will prompt if omitted.", ) admin_group = parser.add_mutually_exclusive_group() admin_group.add_argument( - "-a", "--admin", + "-a", + "--admin", action="store_true", - help="Register new user as an admin. Will prompt if --no-admin is not set either.", + help=( + "Register new user as an admin. " + "Will prompt if --no-admin is not set either." + ), ) admin_group.add_argument( "--no-admin", action="store_true", - help="Register new user as a regular user. Will prompt if --admin is not set either.", + help=( + "Register new user as a regular user. " + "Will prompt if --admin is not set either." + ), ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( - "-c", "--config", + "-c", + "--config", type=argparse.FileType('r'), help="Path to server config file. Used to read in shared secret.", ) group.add_argument( - "-k", "--shared-secret", - help="Shared secret as defined in server config file.", + "-k", "--shared-secret", help="Shared secret as defined in server config file." ) parser.add_argument( @@ -189,7 +201,7 @@ if __name__ == "__main__": default="https://localhost:8448", nargs='?', help="URL to use to talk to the home server. Defaults to " - " 'https://localhost:8448'.", + " 'https://localhost:8448'.", ) args = parser.parse_args() @@ -198,7 +210,7 @@ if __name__ == "__main__": config = yaml.safe_load(args.config) secret = config.get("registration_shared_secret", None) if not secret: - print "No 'registration_shared_secret' defined in config." + print("No 'registration_shared_secret' defined in config.") sys.exit(1) else: secret = args.shared_secret diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index b9b828c154..3c7b606323 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -15,23 +15,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, reactor -from twisted.enterprise import adbapi - -from synapse.storage._base import LoggingTransaction, SQLBaseStore -from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database - import argparse import curses import logging import sys import time import traceback -import yaml from six import string_types +import yaml + +from twisted.enterprise import adbapi +from twisted.internet import defer, reactor + +from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.engines import create_engine +from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("synapse_port_db") @@ -105,6 +105,7 @@ class Store(object): *All* database interactions should go through this object. """ + def __init__(self, db_pool, engine): self.db_pool = db_pool self.database_engine = engine @@ -135,7 +136,8 @@ class Store(object): txn = conn.cursor() return func( LoggingTransaction(txn, desc, self.database_engine, [], []), - *args, **kwargs + *args, + **kwargs ) except self.database_engine.module.DatabaseError as e: if self.database_engine.is_deadlock(e): @@ -158,22 +160,20 @@ class Store(object): def r(txn): txn.execute(sql, args) return txn.fetchall() + return self.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, ", ".join(k for k in headers), - ", ".join("%s" for _ in headers) + ", ".join("%s" for _ in headers), ) try: txn.executemany(sql, rows) - except: - logger.exception( - "Failed to insert: %s", - table, - ) + except Exception: + logger.exception("Failed to insert: %s", table) raise @@ -206,7 +206,7 @@ class Porter(object): "table_name": table, "forward_rowid": 1, "backward_rowid": 0, - } + }, ) forward_chunk = 1 @@ -221,10 +221,10 @@ class Porter(object): table, forward_chunk, backward_chunk ) else: + def delete_all(txn): txn.execute( - "DELETE FROM port_from_sqlite3 WHERE table_name = %s", - (table,) + "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,) ) txn.execute("TRUNCATE %s CASCADE" % (table,)) @@ -232,11 +232,7 @@ class Porter(object): yield self.postgres_store._simple_insert( table="port_from_sqlite3", - values={ - "table_name": table, - "forward_rowid": 1, - "backward_rowid": 0, - } + values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) forward_chunk = 1 @@ -251,12 +247,16 @@ class Porter(object): ) @defer.inlineCallbacks - def handle_table(self, table, postgres_size, table_size, forward_chunk, - backward_chunk): + def handle_table( + self, table, postgres_size, table_size, forward_chunk, backward_chunk + ): logger.info( "Table %s: %i/%i (rows %i-%i) already ported", - table, postgres_size, table_size, - backward_chunk+1, forward_chunk-1, + table, + postgres_size, + table_size, + backward_chunk + 1, + forward_chunk - 1, ) if not table_size: @@ -271,7 +271,9 @@ class Porter(object): return if table in ( - "user_directory", "user_directory_search", "users_who_share_rooms", + "user_directory", + "user_directory_search", + "users_who_share_rooms", "users_in_pubic_room", ): # We don't port these tables, as they're a faff and we can regenreate @@ -283,37 +285,35 @@ class Porter(object): # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. yield self.postgres_store._simple_insert( - table=table, - values={"stream_id": None}, + table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done return forward_select = ( - "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" - % (table,) + "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,) ) backward_select = ( - "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" - % (table,) + "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,) ) do_forward = [True] do_backward = [True] while True: + def r(txn): forward_rows = [] backward_rows = [] if do_forward[0]: - txn.execute(forward_select, (forward_chunk, self.batch_size,)) + txn.execute(forward_select, (forward_chunk, self.batch_size)) forward_rows = txn.fetchall() if not forward_rows: do_forward[0] = False if do_backward[0]: - txn.execute(backward_select, (backward_chunk, self.batch_size,)) + txn.execute(backward_select, (backward_chunk, self.batch_size)) backward_rows = txn.fetchall() if not backward_rows: do_backward[0] = False @@ -325,9 +325,7 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction( - "select", r - ) + headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) if frows or brows: if frows: @@ -339,9 +337,7 @@ class Porter(object): rows = self._convert_rows(table, headers, rows) def insert(txn): - self.postgres_store.insert_many_txn( - txn, table, headers[1:], rows - ) + self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) self.postgres_store._simple_update_one_txn( txn, @@ -362,8 +358,9 @@ class Porter(object): return @defer.inlineCallbacks - def handle_search_table(self, postgres_size, table_size, forward_chunk, - backward_chunk): + def handle_search_table( + self, postgres_size, table_size, forward_chunk, backward_chunk + ): select = ( "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" " FROM event_search as es" @@ -373,8 +370,9 @@ class Porter(object): ) while True: + def r(txn): - txn.execute(select, (forward_chunk, self.batch_size,)) + txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() headers = [column[0] for column in txn.description] @@ -402,18 +400,21 @@ class Porter(object): else: rows_dict.append(d) - txn.executemany(sql, [ - ( - row["event_id"], - row["room_id"], - row["key"], - row["sender"], - row["value"], - row["origin_server_ts"], - row["stream_ordering"], - ) - for row in rows_dict - ]) + txn.executemany( + sql, + [ + ( + row["event_id"], + row["room_id"], + row["key"], + row["sender"], + row["value"], + row["origin_server_ts"], + row["stream_ordering"], + ) + for row in rows_dict + ], + ) self.postgres_store._simple_update_one_txn( txn, @@ -437,7 +438,8 @@ class Porter(object): def setup_db(self, db_config, database_engine): db_conn = database_engine.module.connect( **{ - k: v for k, v in db_config.get("args", {}).items() + k: v + for k, v in db_config.get("args", {}).items() if not k.startswith("cp_") } ) @@ -450,13 +452,11 @@ class Porter(object): def run(self): try: sqlite_db_pool = adbapi.ConnectionPool( - self.sqlite_config["name"], - **self.sqlite_config["args"] + self.sqlite_config["name"], **self.sqlite_config["args"] ) postgres_db_pool = adbapi.ConnectionPool( - self.postgres_config["name"], - **self.postgres_config["args"] + self.postgres_config["name"], **self.postgres_config["args"] ) sqlite_engine = create_engine(sqlite_config) @@ -465,9 +465,7 @@ class Porter(object): self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine) - yield self.postgres_store.execute( - postgres_engine.check_database - ) + yield self.postgres_store.execute(postgres_engine.check_database) # Step 1. Set up databases. self.progress.set_state("Preparing SQLite3") @@ -477,6 +475,7 @@ class Porter(object): self.setup_db(postgres_config, postgres_engine) self.progress.set_state("Creating port tables") + def create_port_table(txn): txn.execute( "CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" @@ -501,10 +500,9 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction( - "alter_table", alter_table - ) - except Exception as e: + yield self.postgres_store.runInteraction("alter_table", alter_table) + except Exception: + # On Error Resume Next pass yield self.postgres_store.runInteraction( @@ -514,11 +512,7 @@ class Porter(object): # Step 2. Get tables. self.progress.set_state("Fetching tables") sqlite_tables = yield self.sqlite_store._simple_select_onecol( - table="sqlite_master", - keyvalues={ - "type": "table", - }, - retcol="name", + table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) postgres_tables = yield self.postgres_store._simple_select_onecol( @@ -545,18 +539,14 @@ class Porter(object): # Step 4. Do the copying. self.progress.set_state("Copying to postgres") yield defer.gatherResults( - [ - self.handle_table(*res) - for res in setup_res - ], - consumeErrors=True, + [self.handle_table(*res) for res in setup_res], consumeErrors=True ) # Step 5. Do final post-processing yield self._setup_state_group_id_seq() self.progress.done() - except: + except Exception: global end_error_exec_info end_error_exec_info = sys.exc_info() logger.exception("") @@ -566,9 +556,7 @@ class Porter(object): def _convert_rows(self, table, headers, rows): bool_col_names = BOOLEAN_COLUMNS.get(table, []) - bool_cols = [ - i for i, h in enumerate(headers) if h in bool_col_names - ] + bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names] class BadValueException(Exception): pass @@ -577,18 +565,21 @@ class Porter(object): if j in bool_cols: return bool(col) elif isinstance(col, string_types) and "\0" in col: - logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) - raise BadValueException(); + logger.warn( + "DROPPING ROW: NUL value in table %s col %s: %r", + table, + headers[j], + col, + ) + raise BadValueException() return col outrows = [] for i, row in enumerate(rows): try: - outrows.append(tuple( - conv(j, col) - for j, col in enumerate(row) - if j > 0 - )) + outrows.append( + tuple(conv(j, col) for j, col in enumerate(row) if j > 0) + ) except BadValueException: pass @@ -616,9 +607,7 @@ class Porter(object): return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction( - "select", r, - ) + headers, rows = yield self.sqlite_store.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -639,7 +628,7 @@ class Porter(object): txn.execute( "SELECT rowid FROM sent_transactions WHERE ts >= ?" " ORDER BY rowid ASC LIMIT 1", - (yesterday,) + (yesterday,), ) rows = txn.fetchall() @@ -657,21 +646,17 @@ class Porter(object): "table_name": "sent_transactions", "forward_rowid": next_chunk, "backward_rowid": 0, - } + }, ) def get_sent_table_size(txn): txn.execute( - "SELECT count(*) FROM sent_transactions" - " WHERE ts >= ?", - (yesterday,) + "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) size, = txn.fetchone() return int(size) - remaining_count = yield self.sqlite_store.execute( - get_sent_table_size - ) + remaining_count = yield self.sqlite_store.execute(get_sent_table_size) total_count = remaining_count + inserted_rows @@ -680,13 +665,11 @@ class Porter(object): @defer.inlineCallbacks def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): frows = yield self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), - forward_chunk, + "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk ) brows = yield self.sqlite_store.execute_sql( - "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), - backward_chunk, + "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk ) defer.returnValue(frows[0][0] + brows[0][0]) @@ -694,7 +677,7 @@ class Porter(object): @defer.inlineCallbacks def _get_already_ported_count(self, table): rows = yield self.postgres_store.execute_sql( - "SELECT count(*) FROM %s" % (table,), + "SELECT count(*) FROM %s" % (table,) ) defer.returnValue(rows[0][0]) @@ -717,22 +700,21 @@ class Porter(object): def _setup_state_group_id_seq(self): def r(txn): txn.execute("SELECT MAX(id) FROM state_groups") - next_id = txn.fetchone()[0]+1 - txn.execute( - "ALTER SEQUENCE state_group_id_seq RESTART WITH %s", - (next_id,), - ) + next_id = txn.fetchone()[0] + 1 + txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) + return self.postgres_store.runInteraction("setup_state_group_id_seq", r) ############################################## -###### The following is simply UI stuff ###### +# The following is simply UI stuff ############################################## class Progress(object): """Used to report progress of the port """ + def __init__(self): self.tables = {} @@ -758,6 +740,7 @@ class Progress(object): class CursesProgress(Progress): """Reports progress to a curses window """ + def __init__(self, stdscr): self.stdscr = stdscr @@ -801,7 +784,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds,) + duration_str = '%02dm %02ds' % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -814,16 +797,12 @@ class CursesProgress(Progress): est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" - status = ( - "Time spent: %s (est. remaining: %s)" - % (duration_str, est_remaining_str,) + status = "Time spent: %s (est. remaining: %s)" % ( + duration_str, + est_remaining_str, ) - self.stdscr.addstr( - 0, 0, - status, - curses.A_BOLD, - ) + self.stdscr.addstr(0, 0, status, curses.A_BOLD) max_len = max([len(t) for t in self.tables.keys()]) @@ -831,9 +810,7 @@ class CursesProgress(Progress): middle_space = 1 items = self.tables.items() - items.sort( - key=lambda i: (i[1]["perc"], i[0]), - ) + items.sort(key=lambda i: (i[1]["perc"], i[0])) for i, (table, data) in enumerate(items): if i + 2 >= rows: @@ -844,9 +821,7 @@ class CursesProgress(Progress): color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) self.stdscr.addstr( - i + 2, left_margin + max_len - len(table), - table, - curses.A_BOLD | color, + i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color ) size = 20 @@ -857,15 +832,13 @@ class CursesProgress(Progress): ) self.stdscr.addstr( - i + 2, left_margin + max_len + middle_space, + i + 2, + left_margin + max_len + middle_space, "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), ) if self.finished: - self.stdscr.addstr( - rows - 1, 0, - "Press any key to exit...", - ) + self.stdscr.addstr(rows - 1, 0, "Press any key to exit...") self.stdscr.refresh() self.last_update = time.time() @@ -877,29 +850,25 @@ class CursesProgress(Progress): def set_state(self, state): self.stdscr.clear() - self.stdscr.addstr( - 0, 0, - state + "...", - curses.A_BOLD, - ) + self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD) self.stdscr.refresh() class TerminalProgress(Progress): """Just prints progress to the terminal """ + def update(self, table, num_done): super(TerminalProgress, self).update(table, num_done) data = self.tables[table] - print "%s: %d%% (%d/%d)" % ( - table, data["perc"], - data["num_done"], data["total"], + print( + "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"]) ) def set_state(self, state): - print state + "..." + print(state + "...") ############################################## @@ -909,34 +878,38 @@ class TerminalProgress(Progress): if __name__ == "__main__": parser = argparse.ArgumentParser( description="A script to port an existing synapse SQLite database to" - " a new PostgreSQL database." + " a new PostgreSQL database." ) parser.add_argument("-v", action='store_true') parser.add_argument( - "--sqlite-database", required=True, + "--sqlite-database", + required=True, help="The snapshot of the SQLite database file. This must not be" - " currently used by a running synapse server" + " currently used by a running synapse server", ) parser.add_argument( - "--postgres-config", type=argparse.FileType('r'), required=True, - help="The database config file for the PostgreSQL database" + "--postgres-config", + type=argparse.FileType('r'), + required=True, + help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', - help="display a curses based progress UI" + "--curses", action='store_true', help="display a curses based progress UI" ) parser.add_argument( - "--batch-size", type=int, default=1000, + "--batch-size", + type=int, + default=1000, help="The number of rows to select from the SQLite table each" - " iteration [default=1000]", + " iteration [default=1000]", ) args = parser.parse_args() logging_config = { "level": logging.DEBUG if args.v else logging.INFO, - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" + "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", } if args.curses: |