diff options
-rw-r--r-- | docs/password_auth_providers.rst | 12 | ||||
-rwxr-xr-x | scripts/synapse_port_db | 26 | ||||
-rw-r--r-- | synapse/storage/prepare_database.py | 70 | ||||
-rw-r--r-- | synapse/storage/schema/schema_version.sql | 7 | ||||
-rw-r--r-- | tests/utils.py | 1 |
5 files changed, 109 insertions, 7 deletions
diff --git a/docs/password_auth_providers.rst b/docs/password_auth_providers.rst index 3da1a67844..ca05a76617 100644 --- a/docs/password_auth_providers.rst +++ b/docs/password_auth_providers.rst @@ -37,3 +37,15 @@ Password auth provider classes must provide the following methods: The method should return a Twisted ``Deferred`` object, which resolves to ``True`` if authentication is successful, and ``False`` if not. + +Optional methods +---------------- + +Password provider classes may optionally provide the following methods. + +*class* ``SomeProvider.get_db_schema_files()`` + + This method, if implemented, should return an Iterable of ``(name, + stream)`` pairs of database schema files. Each file is applied in turn at + initialisation, and a record is then made in the database so that it is + not re-applied on the next start. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index d6d8ee50cb..3a8972efc3 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -320,7 +320,7 @@ class Porter(object): backward_chunk = min(row[0] for row in brows) - 1 rows = frows + brows - self._convert_rows(table, headers, rows) + rows = self._convert_rows(table, headers, rows) def insert(txn): self.postgres_store.insert_many_txn( @@ -556,17 +556,29 @@ class Porter(object): i for i, h in enumerate(headers) if h in bool_col_names ] + class BadValueException(Exception): + pass + def conv(j, col): if j in bool_cols: return bool(col) + elif isinstance(col, basestring) and "\0" in col: + logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) + raise BadValueException(); return col + outrows = [] for i, row in enumerate(rows): - rows[i] = tuple( - conv(j, col) - for j, col in enumerate(row) - if j > 0 - ) + try: + outrows.append(tuple( + conv(j, col) + for j, col in enumerate(row) + if j > 0 + )) + except BadValueException: + pass + + return outrows @defer.inlineCallbacks def _setup_sent_transactions(self): @@ -594,7 +606,7 @@ class Porter(object): "select", r, ) - self._convert_rows("sent_transactions", headers, rows) + rows = self._convert_rows("sent_transactions", headers, rows) inserted_rows = len(rows) if inserted_rows: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index a4e08e6757..d1691bbac2 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config): If `config` is None then prepare_database will assert that no upgrade is necessary, *or* will create a fresh database if the database is empty. + + Args: + db_conn: + database_engine: + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already """ try: cur = db_conn.cursor() @@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config): else: _setup_new_database(cur, database_engine) + # check if any of our configured dynamic modules want a database + if config is not None: + _apply_module_schemas(cur, database_engine, config) + cur.close() db_conn.commit() except Exception: @@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) +def _apply_module_schemas(txn, database_engine, config): + """Apply the module schemas for the dynamic modules, if any + + Args: + cur: database cursor + database_engine: synapse database engine class + config (synapse.config.homeserver.HomeServerConfig): + application config + """ + for (mod, _config) in config.password_providers: + if not hasattr(mod, 'get_db_schema_files'): + continue + modname = ".".join((mod.__module__, mod.__name__)) + _apply_module_schema_files( + txn, database_engine, modname, mod.get_db_schema_files(), + ) + + +def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): + """Apply the module schemas for a single module + + Args: + cur: database cursor + database_engine: synapse database engine class + modname (str): fully qualified name of the module + names_and_streams (Iterable[(str, file)]): the names and streams of + schemas to be applied + """ + cur.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_module_schemas WHERE module_name = ?" + ), + (modname,) + ) + applied_deltas = set(d for d, in cur) + for (name, stream) in names_and_streams: + if name in applied_deltas: + continue + + root_name, ext = os.path.splitext(name) + if ext != '.sql': + raise PrepareDatabaseException( + "only .sql files are currently supported for module schemas", + ) + + logger.info("applying schema %s for %s", name, modname) + for statement in get_statements(stream): + cur.execute(statement) + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_module_schemas (module_name, file)" + " VALUES (?,?)", + ), + (modname, name) + ) + + def get_statements(f): statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index a7ade69986..42e5cb6df5 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas( file TEXT NOT NULL, UNIQUE(version, file) ); + +-- a list of schema files we have loaded on behalf of dynamic modules +CREATE TABLE IF NOT EXISTS applied_module_schemas( + module_name TEXT NOT NULL, + file TEXT NOT NULL, + UNIQUE(module_name, file) +); diff --git a/tests/utils.py b/tests/utils.py index d2ebce4b2e..ed8a7360f5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object): ) self.config = Mock() + self.config.password_providers = [] self.config.database_config = {"name": "sqlite3"} def prepare(self): |