diff --git a/scripts/port_from_sqlite_to_postgres.py b/scripts/port_from_sqlite_to_postgres.py
index 4b3fd9e529..fc1603c1c9 100644
--- a/scripts/port_from_sqlite_to_postgres.py
+++ b/scripts/port_from_sqlite_to_postgres.py
@@ -20,38 +20,17 @@ from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
import argparse
-import itertools
+import curses
import logging
-import types
+import sys
+import time
+import traceback
import yaml
logger = logging.getLogger("port_from_sqlite_to_postgres")
-BINARY_COLUMNS = {
- "event_content_hashes": ["hash"],
- "event_reference_hashes": ["hash"],
- "event_signatures": ["signature"],
- "event_edge_hashes": ["hash"],
- "events": ["content", "unrecognized_keys"],
- "event_json": ["internal_metadata", "json"],
- "application_services_txns": ["event_ids"],
- "received_transactions": ["response_json"],
- "sent_transactions": ["response_json"],
- "server_tls_certificates": ["tls_certificate"],
- "server_signature_keys": ["verify_key"],
- "pushers": ["pushkey", "data"],
- "user_filters": ["filter_json"],
-}
-
-UNICODE_COLUMNS = {
- "events": ["content", "unrecognized_keys"],
- "event_json": ["internal_metadata", "json"],
- "users": ["password_hash"],
-}
-
-
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier"],
"rooms": ["is_public"],
@@ -91,7 +70,15 @@ APPEND_ONLY_TABLES = [
]
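+# Stashes sys.exc_info() from a failed port so the traceback can be printed
+# once the reactor (and any curses screen) has been torn down.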
+end_error_exec_info = None
+
+
class Store(object):
+ """This object is used to pull out some of the convenience API from the
+ Storage layer.
+
+ *All* database interactions should go through this object.
+ """
def __init__(self, db_pool, engine):
self.db_pool = db_pool
self.database_engine = engine
@@ -130,11 +117,14 @@ class Store(object):
continue
raise
except Exception as e:
- logger.debug("[TXN FAIL] {%s}", desc, e)
+ logger.debug("[TXN FAIL] {%s} %s", desc, e)
raise
return self.db_pool.runWithConnection(r)
+ def execute(self, f):
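+ """Runs f in a single transaction, using f.__name__ as the description."""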
+ return self.runInteraction(f.__name__, f)
+
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
@@ -152,205 +142,435 @@ class Store(object):
raise
+class Progress(object):
+ """Used to report progress of the port
+ """
+ def __init__(self):
+ self.tables = {}
+
+ self.start_time = int(time.time())
+
+ def add_table(self, table, cur, size):
+ self.tables[table] = {
+ "start": cur,
+ "num_done": cur,
+ "total": size,
+ "perc": int(cur * 100 / size),
+ }
+
+ def update(self, table, num_done):
+ data = self.tables[table]
+ data["num_done"] = num_done
+ data["perc"] = int(num_done * 100 / data["total"])
+
+ def done(self):
+ pass
+
-def chunks(n):
- for i in itertools.count(0, n):
- yield range(i, i+n)
+class CursesProgress(Progress):
+ """Reports progress to a curses window
+ """
+ def __init__(self, stdscr):
+ self.stdscr = stdscr
+ curses.use_default_colors()
+ curses.curs_set(0)
-@defer.inlineCallbacks
-def handle_table(table, sqlite_store, postgres_store):
- if table in APPEND_ONLY_TABLES:
- # It's safe to just carry on inserting.
- next_chunk = yield postgres_store._simple_select_one_onecol(
- table="port_from_sqlite3",
- keyvalues={"table_name": table},
- retcol="rowid",
- allow_none=True,
+ curses.init_pair(1, curses.COLOR_RED, -1)
+ curses.init_pair(2, curses.COLOR_GREEN, -1)
+
+ self.last_update = 0
+
+ self.finished = False
+
+ super(CursesProgress, self).__init__()
+
+ def update(self, table, num_done):
+ super(CursesProgress, self).update(table, num_done)
+
+ self.render()
+
+ def render(self, force=False):
+ now = time.time()
+
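+ # Throttle redraws to at most one every 0.2 seconds unless forced.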
+ if not force and now - self.last_update < 0.2:
+ return
+
+ self.stdscr.clear()
+
+ rows, cols = self.stdscr.getmaxyx()
+
+ duration = int(now) - int(self.start_time)
+
+ minutes, seconds = divmod(duration, 60)
+ duration_str = '%02dm %02ds' % (minutes, seconds,)
+
+ if self.finished:
+ status = "Time spent: %s (Done!)" % (duration_str,)
+ else:
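+ # Estimate time remaining from the table that has made the least
+ # relative progress since the port started.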
+ min_perc = min(
+ (v["num_done"] - v["start"]) * 100. / (v["total"] - v["start"])
+ if v["total"] - v["start"] else 100
+ for v in self.tables.values()
+ )
+ if min_perc > 0:
+ elapsed = int(now) - self.start_time
+ est_remaining = elapsed * (100 - min_perc) / min_perc
+ est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
+ else:
+ est_remaining_str = "Unknown"
+ status = (
+ "Time spent: %s (est. remaining: %s)"
+ % (duration_str, est_remaining_str,)
+ )
+
+ self.stdscr.addstr(
+ 0, 0,
+ status,
+ curses.A_BOLD,
)
- if next_chunk is None:
- yield postgres_store._simple_insert(
- table="port_from_sqlite3",
- values={"table_name": table, "rowid": 0}
+ max_len = max([len(t) for t in self.tables.keys()])
+
+ left_margin = 5
+ middle_space = 1
+
+ items = sorted(
+ self.tables.items(),
+ key=lambda i: (i[1]["perc"], i[0]),
+ )
+
+ for i, (table, data) in enumerate(items):
+ if i + 2 >= rows:
+ break
+
+ perc = data["perc"]
+
+ color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
+
+ self.stdscr.addstr(
+ i+2, left_margin + max_len - len(table),
+ table,
+ curses.A_BOLD | color,
)
- next_chunk = 0
- else:
- def delete_all(txn):
- txn.execute(
- "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
- (table,)
+ size = 20
+
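+ # Render a fixed-width ASCII progress bar, e.g. [#####     ]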
+ progress = "[%s%s]" % (
+ "#" * int(perc*size/100),
+ " " * (size - int(perc*size/100)),
)
- txn.execute("TRUNCATE %s CASCADE" % (table,))
- postgres_store._simple_insert_txn(
- txn,
- table="port_from_sqlite3",
- values={"table_name": table, "rowid": 0}
+
+ self.stdscr.addstr(
+ i+2, left_margin + max_len + middle_space,
+ "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)
- yield postgres_store.runInteraction(
- "delete_non_append_only", delete_all
- )
+ if self.finished:
+ self.stdscr.addstr(
+ rows - 1, 0,
+ "Press any key to exit...",
+ )
- next_chunk = 0
+ self.stdscr.refresh()
+ self.last_update = time.time()
- logger.info("next_chunk for %s: %d", table, next_chunk)
+ def done(self):
+ self.finished = True
+ self.render(True)
+ self.stdscr.getch()
- N = 5000
+ def on_prepare_sqlite(self):
+ self.stdscr.clear()
+ self.stdscr.addstr(
+ 0, 0,
+ "Preparing SQLite database...",
+ curses.A_BOLD,
+ )
+ self.stdscr.refresh()
+
+ def on_prepare_postgres(self):
+ self.stdscr.clear()
+ self.stdscr.addstr(
+ 0, 0,
+ "Preparing PostgreSQL database...",
+ curses.A_BOLD,
+ )
+ self.stdscr.refresh()
+
+ def fetching_tables(self):
+ self.stdscr.clear()
+ self.stdscr.addstr(
+ 0, 0,
+ "Fetching tables...",
+ curses.A_BOLD,
+ )
+ self.stdscr.refresh()
+
+ def preparing_tables(self):
+ self.stdscr.clear()
+ self.stdscr.addstr(
+ 0, 0,
+ "Preparing tables...",
+ curses.A_BOLD,
+ )
+ self.stdscr.refresh()
- select = "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
- uni_col_names = UNICODE_COLUMNS.get(table, [])
- bool_col_names = BOOLEAN_COLUMNS.get(table, [])
- bin_col_names = BINARY_COLUMNS.get(table, [])
+class TerminalProgress(Progress):
+ """Just prints progress to the terminal
+ """
+ def update(self, table, num_done):
+ super(TerminalProgress, self).update(table, num_done)
- while True:
- def r(txn):
- txn.execute(select, (next_chunk, N,))
- rows = txn.fetchall()
- headers = [column[0] for column in txn.description]
+ data = self.tables[table]
- return headers, rows
+ print "%s: %d%% (%d/%d)" % (
+ table, data["perc"],
+ data["num_done"], data["total"],
+ )
- headers, rows = yield sqlite_store.runInteraction("select", r)
+ def on_prepare_sqlite(self):
+ print "Preparing SQLite database..."
- logger.info("Got %d rows for %s", len(rows), table)
+ def on_prepare_postgres(self):
+ print "Preparing PostgreSQL database..."
- if rows:
- uni_cols = [i for i, h in enumerate(headers) if h in uni_col_names]
- bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
- bin_cols = [i for i, h in enumerate(headers) if h in bin_col_names]
- next_chunk = rows[-1][0] + 1
+ def fetching_tables(self):
+ print "Fetching tables..."
- def conv(j, col):
- if j in uni_cols:
- col = sqlite_store.database_engine.load_unicode(col)
- if j in bool_cols:
- return bool(col)
+ def preparing_tables(self):
+ print "Preparing tables..."
- if j in bin_cols:
- if isinstance(col, types.UnicodeType):
- col = buffer(col.encode("utf8"))
- return col
+class Porter(object):
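+ """Drives the port from the SQLite database to the PostgreSQL database,
+ reporting progress through the supplied Progress instance.
+
+ Expected kwargs: sqlite_config, postgres_config, progress, batch_size.
+ """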
+ def __init__(self, **kwargs):
+ self.__dict__.update(kwargs)
- for i, row in enumerate(rows):
- rows[i] = tuple(
- postgres_store.database_engine.encode_parameter(
- conv(j, col)
- )
- for j, col in enumerate(row)
- if j > 0
- )
+ @defer.inlineCallbacks
+ def handle_table(self, table):
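+ """Copies a single table across. Append-only tables resume from the
+ rowid recorded in port_from_sqlite3; all other tables are truncated
+ and re-copied from scratch.
+ """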
+ if table in APPEND_ONLY_TABLES:
+ # It's safe to just carry on inserting.
+ next_chunk = yield self.postgres_store._simple_select_one_onecol(
+ table="port_from_sqlite3",
+ keyvalues={"table_name": table},
+ retcol="rowid",
+ allow_none=True,
+ )
- def ins(txn):
- postgres_store.insert_many_txn(txn, table, headers[1:], rows)
+ if next_chunk is None:
+ yield self.postgres_store._simple_insert(
+ table="port_from_sqlite3",
+ values={"table_name": table, "rowid": 1}
+ )
- postgres_store._simple_update_one_txn(
+ next_chunk = 1
+ else:
+ def delete_all(txn):
+ txn.execute(
+ "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
+ (table,)
+ )
+ txn.execute("TRUNCATE %s CASCADE" % (table,))
+ self.postgres_store._simple_insert_txn(
txn,
table="port_from_sqlite3",
- keyvalues={"table_name": table},
- updatevalues={"rowid": next_chunk},
+ values={"table_name": table, "rowid": 0}
)
+ yield self.postgres_store.execute(delete_all)
- yield postgres_store.runInteraction("insert_many", ins)
- else:
+ next_chunk = 1
+
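+ # Compare row counts: the SQLite total tells us how much there is to
+ # do, the PostgreSQL count tells us how much has already been ported.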
+ def get_table_size(txn):
+ txn.execute("SELECT count(*) FROM %s" % (table,))
+ size, = txn.fetchone()
+ return int(size)
+
+ table_size = yield self.sqlite_store.execute(get_table_size)
+ postgres_size = yield self.postgres_store.execute(get_table_size)
+
+ if not table_size:
return
+ self.progress.add_table(table, postgres_size, table_size)
-def setup_db(db_config, database_engine):
- db_conn = database_engine.module.connect(
- **{
- k: v for k, v in db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- )
+ select = (
+ "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
+ % (table,)
+ )
- database_engine.prepare_database(db_conn)
+ bool_col_names = BOOLEAN_COLUMNS.get(table, [])
- db_conn.commit()
+ while True:
+ def r(txn):
+ txn.execute(select, (next_chunk, self.batch_size,))
+ rows = txn.fetchall()
+ headers = [column[0] for column in txn.description]
+ return headers, rows
-@defer.inlineCallbacks
-def main(sqlite_config, postgress_config):
- try:
- sqlite_db_pool = adbapi.ConnectionPool(
- sqlite_config["name"],
- **sqlite_config["args"]
- )
+ headers, rows = yield self.sqlite_store.runInteraction("select", r)
- postgres_db_pool = adbapi.ConnectionPool(
- postgress_config["name"],
- **postgress_config["args"]
- )
+ if rows:
+ bool_cols = [
+ i for i, h in enumerate(headers) if h in bool_col_names
+ ]
+ next_chunk = rows[-1][0] + 1
+
+ def conv(j, col):
+ if j in bool_cols:
+ return bool(col)
+ return col
+
+ for i, row in enumerate(rows):
+ rows[i] = tuple(
+ self.postgres_store.database_engine.encode_parameter(
+ conv(j, col)
+ )
+ for j, col in enumerate(row)
+ if j > 0
+ )
- sqlite_engine = create_engine("sqlite3")
- postgres_engine = create_engine("psycopg2")
+ def insert(txn):
+ self.postgres_store.insert_many_txn(
+ txn, table, headers[1:], rows
+ )
- sqlite_store = Store(sqlite_db_pool, sqlite_engine)
- postgres_store = Store(postgres_db_pool, postgres_engine)
+ self.postgres_store._simple_update_one_txn(
+ txn,
+ table="port_from_sqlite3",
+ keyvalues={"table_name": table},
+ updatevalues={"rowid": next_chunk},
+ )
- # Step 1. Set up databases.
- logger.info("Preparing sqlite database...")
- setup_db(sqlite_config, sqlite_engine)
+ yield self.postgres_store.execute(insert)
- logger.info("Preparing postgres database...")
- setup_db(postgress_config, postgres_engine)
+ postgres_size += len(rows)
- # Step 2. Get tables.
- logger.info("Fetching tables...")
- tables = yield sqlite_store._simple_select_onecol(
- table="sqlite_master",
- keyvalues={
- "type": "table",
- },
- retcol="name",
+ self.progress.update(table, postgres_size)
+ else:
+ return
+
+ def setup_db(self, db_config, database_engine):
+ db_conn = database_engine.module.connect(
+ **{
+ k: v for k, v in db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
)
- logger.info("Found %d tables", len(tables))
+ database_engine.prepare_database(db_conn)
- def create_port_table(txn):
- txn.execute(
- "CREATE TABLE port_from_sqlite3 ("
- " table_name varchar(100) NOT NULL UNIQUE,"
- " rowid bigint NOT NULL"
- ")"
- )
+ db_conn.commit()
+ @defer.inlineCallbacks
+ def run(self):
try:
- yield postgres_store.runInteraction(
- "create_port_table", create_port_table
+ sqlite_db_pool = adbapi.ConnectionPool(
+ self.sqlite_config["name"],
+ **self.sqlite_config["args"]
)
- except Exception as e:
- logger.info("Failed to create port table: %s", e)
-
- # Process tables.
- yield defer.gatherResults(
- [
- handle_table(table, sqlite_store, postgres_store)
- for table in tables
- if table not in ["schema_version", "applied_schema_deltas"]
- and not table.startswith("sqlite_")
- ],
- consumeErrors=True,
- )
- except:
- logger.exception("")
- finally:
- reactor.stop()
+ postgres_db_pool = adbapi.ConnectionPool(
+ self.postgres_config["name"],
+ **self.postgres_config["args"]
+ )
+
+ sqlite_engine = create_engine("sqlite3")
+ postgres_engine = create_engine("psycopg2")
+
+ self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
+ self.postgres_store = Store(postgres_db_pool, postgres_engine)
+
+ # Step 1. Set up databases.
+ self.progress.on_prepare_sqlite()
+ self.setup_db(self.sqlite_config, sqlite_engine)
+
+ self.progress.on_prepare_postgres()
+ self.setup_db(self.postgres_config, postgres_engine)
+
+ # Step 2. Get tables.
+ self.progress.fetching_tables()
+ sqlite_tables = yield self.sqlite_store._simple_select_onecol(
+ table="sqlite_master",
+ keyvalues={
+ "type": "table",
+ },
+ retcol="name",
+ )
+
+ postgres_tables = yield self.postgres_store._simple_select_onecol(
+ table="information_schema.tables",
+ keyvalues={
+ "table_schema": "public",
+ },
+ retcol="distinct table_name",
+ )
+
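+ # Only port tables that exist in both databases.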
+ tables = set(sqlite_tables) & set(postgres_tables)
+
+ self.progress.preparing_tables()
+
+ logger.info("Found %d tables", len(tables))
+
+ def create_port_table(txn):
+ txn.execute(
+ "CREATE TABLE port_from_sqlite3 ("
+ " table_name varchar(100) NOT NULL UNIQUE,"
+ " rowid bigint NOT NULL"
+ ")"
+ )
+
+ try:
+ yield self.postgres_store.runInteraction(
+ "create_port_table", create_port_table
+ )
+ except Exception as e:
+ logger.info("Failed to create port table: %s", e)
+
+ # Process tables.
+ yield defer.gatherResults(
+ [
+ self.handle_table(table)
+ for table in tables
+ if table not in ["schema_version", "applied_schema_deltas"]
+ and not table.startswith("sqlite_")
+ ],
+ consumeErrors=True,
+ )
+
+ self.progress.done()
+ except:
+ global end_error_exec_info
+ end_error_exec_info = sys.exc_info()
+ logger.exception("")
+ finally:
+ reactor.stop()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
+ parser.add_argument("-v", action='store_true')
+ parser.add_argument("--curses", action='store_true')
parser.add_argument("--sqlite-database")
parser.add_argument(
"--postgres-config", type=argparse.FileType('r'),
)
+ parser.add_argument("--batch-size", type=int, default=1000)
+
args = parser.parse_args()
- logging.basicConfig(level=logging.INFO)
+
+ logging_config = {
+ "level": logging.DEBUG if args.v else logging.INFO,
+ "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
+ }
+
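+ # Writing log output to stdout would corrupt the curses display, so
+ # send it to a file when --curses is in use.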
+ if args.curses:
+ logging_config["filename"] = "port-synapse.log"
+
+ logging.basicConfig(**logging_config)
sqlite_config = {
"name": "sqlite3",
@@ -364,10 +584,28 @@ if __name__ == "__main__":
postgres_config = yaml.safe_load(args.postgres_config)
- reactor.callWhenRunning(
- main,
- sqlite_config=sqlite_config,
- postgres_config=postgres_config,
- )
+ def start(stdscr=None):
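+ """Builds the appropriate progress reporter and kicks off the port."""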
+ if stdscr:
+ progress = CursesProgress(stdscr)
+ else:
+ progress = TerminalProgress()
+
+ porter = Porter(
+ sqlite_config=sqlite_config,
+ postgres_config=postgres_config,
+ progress=progress,
+ batch_size=args.batch_size,
+ )
+
+ reactor.callWhenRunning(porter.run)
+
+ reactor.run()
+
+ if args.curses:
+ curses.wrapper(start)
+ else:
+ start()
- reactor.run()
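+ # If the port failed, print the stashed traceback now that the terminal
+ # (and any curses screen) has been restored.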
+ if end_error_exec_info:
+ exc_type, exc_value, exc_traceback = end_error_exec_info
+ traceback.print_exception(exc_type, exc_value, exc_traceback)
|