diff options
Diffstat (limited to 'scripts')
-rwxr-xr-x | scripts/export_signing_key | 94 | ||||
-rwxr-xr-x | scripts/generate_log_config | 43 | ||||
-rwxr-xr-x | scripts/hash_password | 2 | ||||
-rwxr-xr-x | scripts/move_remote_media_to_new_store.py | 2 | ||||
-rwxr-xr-x | scripts/synapse_port_db | 447 |
5 files changed, 410 insertions, 178 deletions
diff --git a/scripts/export_signing_key b/scripts/export_signing_key new file mode 100755 index 0000000000..8aec9d802b --- /dev/null +++ b/scripts/export_signing_key @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import sys +import time +from typing import Optional + +import nacl.signing +from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys + + +def exit(status: int = 0, message: Optional[str] = None): + if message: + print(message, file=sys.stderr) + sys.exit(status) + + +def format_plain(public_key: nacl.signing.VerifyKey): + print( + "%s:%s %s" + % (public_key.alg, public_key.version, encode_verify_key_base64(public_key),) + ) + + +def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): + print( + ' "%s:%s": { key: "%s", expired_ts: %i }' + % ( + public_key.alg, + public_key.version, + encode_verify_key_base64(public_key), + expiry_ts, + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read", + ) + + parser.add_argument( + "-x", + action="store_true", + dest="for_config", + help="format the output for inclusion in the old_signing_keys config setting", + ) + + parser.add_argument( + "--expiry-ts", + type=int, + default=int(time.time() * 1000) + 6*3600000, + help=( + "The expiry time to use for -x, in milliseconds since 1970. The default " + "is (now+6h)." + ), + ) + + args = parser.parse_args() + + formatter = ( + (lambda k: format_for_config(k, args.expiry_ts)) + if args.for_config + else format_plain + ) + + keys = [] + for file in args.key_file: + try: + res = read_signing_keys(file) + except Exception as e: + exit( + status=1, + message="Error reading key from file %s: %s %s" + % (file.name, type(e), e), + ) + res = [] + for key in res: + formatter(get_verify_key(key)) diff --git a/scripts/generate_log_config b/scripts/generate_log_config new file mode 100755 index 0000000000..b6957f48a3 --- /dev/null +++ b/scripts/generate_log_config @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys + +from synapse.config.logger import DEFAULT_LOG_CONFIG + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-o", + "--output-file", + type=argparse.FileType("w"), + default=sys.stdout, + help="File to write the configuration to. Default: stdout", + ) + + parser.add_argument( + "-f", + "--log-file", + type=str, + default="/var/log/matrix-synapse/homeserver.log", + help="name of the log file", + ) + + args = parser.parse_args() + args.output_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) diff --git a/scripts/hash_password b/scripts/hash_password index a1eb0769da..a30767f758 100755 --- a/scripts/hash_password +++ b/scripts/hash_password @@ -52,7 +52,7 @@ if __name__ == "__main__": if "config" in args and args.config: config = yaml.safe_load(args.config) bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds) - password_config = config.get("password_config", {}) + password_config = config.get("password_config", None) or {} password_pepper = password_config.get("pepper", password_pepper) password = args.password diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index 12747c6024..b5b63933ab 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths): # check that the original exists original_file = src_paths.remote_media_filepath(origin_server, file_id) if not os.path.exists(original_file): - logger.warn( + logger.warning( "Original for %s/%s (%s) does not exist", origin_server, file_id, diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index b6ba19c776..9a0fbc61d8 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,12 +27,44 @@ from six import string_types import yaml -from twisted.enterprise import adbapi from twisted.internet import defer, reactor -from synapse.storage._base import LoggingTransaction, SQLBaseStore +import synapse +from synapse.config.database import DatabaseConnectionConfig +from synapse.config.homeserver import HomeServerConfig +from synapse.logging.context import ( + LoggingContext, + make_deferred_yieldable, + run_in_background, +) +from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore +from synapse.storage.data_stores.main.deviceinbox import ( + DeviceInboxBackgroundUpdateStore, +) +from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore +from synapse.storage.data_stores.main.events_bg_updates import ( + EventsBackgroundUpdatesStore, +) +from synapse.storage.data_stores.main.media_repository import ( + MediaRepositoryBackgroundUpdateStore, +) +from synapse.storage.data_stores.main.registration import ( + RegistrationBackgroundUpdateStore, +) +from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore +from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore +from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore +from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore +from synapse.storage.data_stores.main.stats import StatsStore +from synapse.storage.data_stores.main.user_directory import ( + UserDirectoryBackgroundUpdateStore, +) +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +from synapse.util import Clock +from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse_port_db") @@ -43,6 +76,7 @@ BOOLEAN_COLUMNS = { "presence_list": ["accepted"], "presence_stream": ["currently_active"], "public_room_list_stream": ["visibility"], + "devices": ["hidden"], "device_lists_outbound_pokes": ["sent"], "users_who_share_rooms": ["share_private"], "groups": ["is_public"], @@ -55,6 +89,8 @@ BOOLEAN_COLUMNS = { "local_group_membership": ["is_publicised", "is_admin"], "e2e_room_keys": ["is_verified"], "account_validity": ["email_sent"], + "redactions": ["have_censored"], + "room_stats_state": ["is_federatable"], } @@ -86,79 +122,47 @@ APPEND_ONLY_TABLES = [ "presence_stream", "push_rules_stream", "ex_outlier_stream", - "cache_invalidation_stream", + "cache_invalidation_stream_by_instance", "public_room_list_stream", "state_group_edges", "stream_ordering_to_exterm", ] +# Error returned by the run function. Used at the top-level part of the script to +# handle errors and return codes. +end_error = None +# The exec_info for the error, if any. If error is defined but not exec_info the script +# will show only the error message without the stacktrace, if exec_info is defined but +# not the error then the script will show nothing outside of what's printed in the run +# function. If both are defined, the script will print both the error and the stacktrace. end_error_exec_info = None -class Store(object): - """This object is used to pull out some of the convenience API from the - Storage layer. - - *All* database interactions should go through this object. - """ - - def __init__(self, db_pool, engine): - self.db_pool = db_pool - self.database_engine = engine - - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] - _simple_insert = SQLBaseStore.__dict__["_simple_insert"] - - _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"] - _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"] - _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"] - _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"] - _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"] - _simple_select_one_onecol_txn = SQLBaseStore.__dict__[ - "_simple_select_one_onecol_txn" - ] - - _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] - _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] - _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"] - - def runInteraction(self, desc, func, *args, **kwargs): - def r(conn): - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - return func( - LoggingTransaction(txn, desc, self.database_engine, [], []), - *args, - **kwargs - ) - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) - if i < N: - i += 1 - conn.rollback() - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", desc, e) - raise - - return self.db_pool.runWithConnection(r) - +class Store( + ClientIpBackgroundUpdateStore, + DeviceInboxBackgroundUpdateStore, + DeviceBackgroundUpdateStore, + EventsBackgroundUpdatesStore, + MediaRepositoryBackgroundUpdateStore, + RegistrationBackgroundUpdateStore, + RoomBackgroundUpdateStore, + RoomMemberBackgroundUpdateStore, + SearchBackgroundUpdateStore, + StateBackgroundUpdateStore, + MainStateBackgroundUpdateStore, + UserDirectoryBackgroundUpdateStore, + StatsStore, +): def execute(self, f, *args, **kwargs): - return self.runInteraction(f.__name__, f, *args, **kwargs) + return self.db.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.runInteraction("execute_sql", r) + return self.db.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -173,16 +177,37 @@ class Store(object): logger.exception("Failed to insert: %s", table) raise + def set_room_is_public(self, room_id, is_public): + raise Exception( + "Attempt to set room_is_public during port_db: database not empty?" + ) + + +class MockHomeserver: + def __init__(self, config): + self.clock = Clock(reactor) + self.config = config + self.hostname = config.server_name + self.version_string = "Synapse/" + get_version_string(synapse) + + def get_clock(self): + return self.clock + + def get_reactor(self): + return reactor + + def get_instance_name(self): + return "master" + class Porter(object): def __init__(self, **kwargs): self.__dict__.update(kwargs) - @defer.inlineCallbacks - def setup_table(self, table): + async def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store._simple_select_one( + row = await self.postgres_store.db.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -192,12 +217,14 @@ class Porter(object): total_to_port = None if row is None: if table == "sent_transactions": - forward_chunk, already_ported, total_to_port = ( - yield self._setup_sent_transactions() - ) + ( + forward_chunk, + already_ported, + total_to_port, + ) = await self._setup_sent_transactions() backward_chunk = 0 else: - yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -214,7 +241,7 @@ class Porter(object): backward_chunk = row["backward_rowid"] if total_to_port is None: - already_ported, total_to_port = yield self._get_total_count_to_port( + already_ported, total_to_port = await self._get_total_count_to_port( table, forward_chunk, backward_chunk ) else: @@ -225,9 +252,9 @@ class Porter(object): ) txn.execute("TRUNCATE %s CASCADE" % (table,)) - yield self.postgres_store.execute(delete_all) + await self.postgres_store.execute(delete_all) - yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -235,16 +262,13 @@ class Porter(object): forward_chunk = 1 backward_chunk = 0 - already_ported, total_to_port = yield self._get_total_count_to_port( + already_ported, total_to_port = await self._get_total_count_to_port( table, forward_chunk, backward_chunk ) - defer.returnValue( - (table, already_ported, total_to_port, forward_chunk, backward_chunk) - ) + return table, already_ported, total_to_port, forward_chunk, backward_chunk - @defer.inlineCallbacks - def handle_table( + async def handle_table( self, table, postgres_size, table_size, forward_chunk, backward_chunk ): logger.info( @@ -262,7 +286,7 @@ class Porter(object): self.progress.add_table(table, postgres_size, table_size) if table == "event_search": - yield self.handle_search_table( + await self.handle_search_table( postgres_size, table_size, forward_chunk, backward_chunk ) return @@ -281,7 +305,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -322,7 +346,9 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) + headers, frows, brows = await self.sqlite_store.db.runInteraction( + "select", r + ) if frows or brows: if frows: @@ -336,7 +362,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -346,7 +372,7 @@ class Porter(object): }, ) - yield self.postgres_store.execute(insert) + await self.postgres_store.execute(insert) postgres_size += len(rows) @@ -354,8 +380,7 @@ class Porter(object): else: return - @defer.inlineCallbacks - def handle_search_table( + async def handle_search_table( self, postgres_size, table_size, forward_chunk, backward_chunk ): select = ( @@ -375,7 +400,7 @@ class Porter(object): return headers, rows - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = await self.sqlite_store.db.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -392,8 +417,8 @@ class Porter(object): rows_dict = [] for row in rows: d = dict(zip(headers, row)) - if "\0" in d['value']: - logger.warn('dropping search row %s', d) + if "\0" in d["value"]: + logger.warning("dropping search row %s", d) else: rows_dict.append(d) @@ -413,7 +438,7 @@ class Porter(object): ], ) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -423,7 +448,7 @@ class Porter(object): }, ) - yield self.postgres_store.execute(insert) + await self.postgres_store.execute(insert) postgres_size += len(rows) @@ -432,44 +457,85 @@ class Porter(object): else: return - def setup_db(self, db_config, database_engine): - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) + def build_db_store( + self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False, + ): + """Builds and returns a database store using the provided configuration. - prepare_database(db_conn, database_engine, config=None) + Args: + db_config: The database configuration + allow_outdated_version: True to suppress errors about the database server + version being too old to run a complete synapse - db_conn.commit() + Returns: + The built Store object. + """ + self.progress.set_state("Preparing %s" % db_config.config["name"]) - @defer.inlineCallbacks - def run(self): - try: - sqlite_db_pool = adbapi.ConnectionPool( - self.sqlite_config["name"], **self.sqlite_config["args"] + engine = create_engine(db_config.config) + + hs = MockHomeserver(self.hs_config) + + with make_conn(db_config, engine) as db_conn: + engine.check_database( + db_conn, allow_outdated_version=allow_outdated_version ) + prepare_database(db_conn, engine, config=self.hs_config) + store = Store(Database(hs, db_config, engine), db_conn, hs) + db_conn.commit() + + return store - postgres_db_pool = adbapi.ConnectionPool( - self.postgres_config["name"], **self.postgres_config["args"] + async def run_background_updates_on_postgres(self): + # Manually apply all background updates on the PostgreSQL database. + postgres_ready = ( + await self.postgres_store.db.updates.has_completed_background_updates() + ) + + if not postgres_ready: + # Only say that we're running background updates when there are background + # updates to run. + self.progress.set_state("Running background updates on PostgreSQL") + + while not postgres_ready: + await self.postgres_store.db.updates.do_next_background_update(100) + postgres_ready = await ( + self.postgres_store.db.updates.has_completed_background_updates() ) - sqlite_engine = create_engine(sqlite_config) - postgres_engine = create_engine(postgres_config) + async def run(self): + """Ports the SQLite database to a PostgreSQL database. - self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) - self.postgres_store = Store(postgres_db_pool, postgres_engine) + When a fatal error is met, its message is assigned to the global "end_error" + variable. When this error comes with a stacktrace, its exec_info is assigned to + the global "end_error_exec_info" variable. + """ + global end_error + + try: + # we allow people to port away from outdated versions of sqlite. + self.sqlite_store = self.build_db_store( + DatabaseConnectionConfig("master-sqlite", self.sqlite_config), + allow_outdated_version=True, + ) - yield self.postgres_store.execute(postgres_engine.check_database) + # Check if all background updates are done, abort if not. + updates_complete = ( + await self.sqlite_store.db.updates.has_completed_background_updates() + ) + if not updates_complete: + end_error = ( + "Pending background updates exist in the SQLite3 database." + " Please start Synapse again and wait until every update has finished" + " before running this script.\n" + ) + return - # Step 1. Set up databases. - self.progress.set_state("Preparing SQLite3") - self.setup_db(sqlite_config, sqlite_engine) + self.postgres_store = self.build_db_store( + self.hs_config.get_single_database() + ) - self.progress.set_state("Preparing PostgreSQL") - self.setup_db(postgres_config, postgres_engine) + await self.run_background_updates_on_postgres() self.progress.set_state("Creating port tables") @@ -497,22 +563,22 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction("alter_table", alter_table) + await self.postgres_store.db.runInteraction("alter_table", alter_table) except Exception: # On Error Resume Next pass - yield self.postgres_store.runInteraction( + await self.postgres_store.db.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store._simple_select_onecol( + sqlite_tables = await self.sqlite_store.db.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store._simple_select_onecol( + postgres_tables = await self.postgres_store.db.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -523,28 +589,34 @@ class Porter(object): # Step 3. Figure out what still needs copying self.progress.set_state("Checking on port progress") - setup_res = yield defer.gatherResults( - [ - self.setup_table(table) - for table in tables - if table not in ["schema_version", "applied_schema_deltas"] - and not table.startswith("sqlite_") - ], - consumeErrors=True, + setup_res = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(self.setup_table, table) + for table in tables + if table not in ["schema_version", "applied_schema_deltas"] + and not table.startswith("sqlite_") + ], + consumeErrors=True, + ) ) # Step 4. Do the copying. self.progress.set_state("Copying to postgres") - yield defer.gatherResults( - [self.handle_table(*res) for res in setup_res], consumeErrors=True + await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(self.handle_table, *res) for res in setup_res], + consumeErrors=True, + ) ) # Step 5. Do final post-processing - yield self._setup_state_group_id_seq() + await self._setup_state_group_id_seq() self.progress.done() - except Exception: + except Exception as e: global end_error_exec_info + end_error = e end_error_exec_info = sys.exc_info() logger.exception("") finally: @@ -561,8 +633,10 @@ class Porter(object): def conv(j, col): if j in bool_cols: return bool(col) + if isinstance(col, bytes): + return bytearray(col) elif isinstance(col, string_types) and "\0" in col: - logger.warn( + logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], @@ -582,8 +656,7 @@ class Porter(object): return outrows - @defer.inlineCallbacks - def _setup_sent_transactions(self): + async def _setup_sent_transactions(self): # Only save things from the last day yesterday = int(time.time() * 1000) - 86400000 @@ -600,11 +673,11 @@ class Porter(object): rows = txn.fetchall() headers = [column[0] for column in txn.description] - ts_ind = headers.index('ts') + ts_ind = headers.index("ts") return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = await self.sqlite_store.db.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -617,7 +690,7 @@ class Porter(object): txn, "sent_transactions", headers[1:], rows ) - yield self.postgres_store.execute(insert) + await self.postgres_store.execute(insert) else: max_inserted_rowid = 0 @@ -634,10 +707,10 @@ class Porter(object): else: return 1 - next_chunk = yield self.sqlite_store.execute(get_start_id) + next_chunk = await self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store._simple_insert( + await self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -650,57 +723,63 @@ class Porter(object): txn.execute( "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) - size, = txn.fetchone() + (size,) = txn.fetchone() return int(size) - remaining_count = yield self.sqlite_store.execute(get_sent_table_size) + remaining_count = await self.sqlite_store.execute(get_sent_table_size) total_count = remaining_count + inserted_rows - defer.returnValue((next_chunk, inserted_rows, total_count)) + return next_chunk, inserted_rows, total_count - @defer.inlineCallbacks - def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): - frows = yield self.sqlite_store.execute_sql( + async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): + frows = await self.sqlite_store.execute_sql( "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk ) - brows = yield self.sqlite_store.execute_sql( + brows = await self.sqlite_store.execute_sql( "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk ) - defer.returnValue(frows[0][0] + brows[0][0]) + return frows[0][0] + brows[0][0] - @defer.inlineCallbacks - def _get_already_ported_count(self, table): - rows = yield self.postgres_store.execute_sql( + async def _get_already_ported_count(self, table): + rows = await self.postgres_store.execute_sql( "SELECT count(*) FROM %s" % (table,) ) - defer.returnValue(rows[0][0]) + return rows[0][0] - @defer.inlineCallbacks - def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): - remaining, done = yield defer.gatherResults( - [ - self._get_remaining_count_to_port(table, forward_chunk, backward_chunk), - self._get_already_ported_count(table), - ], - consumeErrors=True, + async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): + remaining, done = await make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self._get_remaining_count_to_port, + table, + forward_chunk, + backward_chunk, + ), + run_in_background(self._get_already_ported_count, table), + ], + ) ) remaining = int(remaining) if remaining else 0 done = int(done) if done else 0 - defer.returnValue((done, remaining + done)) + return done, remaining + done def _setup_state_group_id_seq(self): def r(txn): txn.execute("SELECT MAX(id) FROM state_groups") - next_id = txn.fetchone()[0] + 1 + curr_id = txn.fetchone()[0] + if not curr_id: + return + next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) ############################################## @@ -781,7 +860,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds) + duration_str = "%02dm %02ds" % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -791,7 +870,7 @@ class CursesProgress(Progress): left = float(self.total_remaining) / self.total_processed est_remaining = (int(now) - self.start_time) * left - est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) + est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" status = "Time spent: %s (est. remaining: %s)" % ( @@ -877,7 +956,7 @@ if __name__ == "__main__": description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." ) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--sqlite-database", required=True, @@ -886,12 +965,12 @@ if __name__ == "__main__": ) parser.add_argument( "--postgres-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', help="display a curses based progress UI" + "--curses", action="store_true", help="display a curses based progress UI" ) parser.add_argument( @@ -924,18 +1003,24 @@ if __name__ == "__main__": }, } - postgres_config = yaml.safe_load(args.postgres_config) + hs_config = yaml.safe_load(args.postgres_config) + + if "database" not in hs_config: + sys.stderr.write("The configuration file must have a 'database' section.\n") + sys.exit(4) - if "database" in postgres_config: - postgres_config = postgres_config["database"] + postgres_config = hs_config["database"] if "name" not in postgres_config: - sys.stderr.write("Malformed database config: no 'name'") + sys.stderr.write("Malformed database config: no 'name'\n") sys.exit(2) if postgres_config["name"] != "psycopg2": - sys.stderr.write("Database must use 'psycopg2' connector.") + sys.stderr.write("Database must use the 'psycopg2' connector.\n") sys.exit(3) + config = HomeServerConfig() + config.parse_config_dict(hs_config, "", "") + def start(stdscr=None): if stdscr: progress = CursesProgress(stdscr) @@ -944,12 +1029,17 @@ if __name__ == "__main__": porter = Porter( sqlite_config=sqlite_config, - postgres_config=postgres_config, progress=progress, batch_size=args.batch_size, + hs_config=config, ) - reactor.callWhenRunning(porter.run) + @defer.inlineCallbacks + def run(): + with LoggingContext("synapse_port_db_run"): + yield defer.ensureDeferred(porter.run()) + + reactor.callWhenRunning(run) reactor.run() @@ -958,6 +1048,11 @@ if __name__ == "__main__": else: start() - if end_error_exec_info: - exc_type, exc_value, exc_traceback = end_error_exec_info - traceback.print_exception(exc_type, exc_value, exc_traceback) + if end_error: + if end_error_exec_info: + exc_type, exc_value, exc_traceback = end_error_exec_info + traceback.print_exception(exc_type, exc_value, exc_traceback) + + sys.stderr.write(end_error) + + sys.exit(5) |