diff options
author | Richard van der Hoff <richard@matrix.org> | 2020-01-09 17:21:30 +0000 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2020-01-09 18:05:50 +0000 |
commit | e97d1cf0014668b9d4883d4175b783088444b24b (patch) | |
tree | f490b673ca4de0c7bd3ed53081f7f92a84394de2 /synapse/storage/engines | |
parent | Allow admin users to create or modify users without a shared secret (#6495) (diff) | |
download | synapse-e97d1cf0014668b9d4883d4175b783088444b24b.tar.xz |
Modify check_database to take a connection rather than a cursor
We might not need the cursor at all.
Diffstat (limited to 'synapse/storage/engines')
-rw-r--r-- | synapse/storage/engines/postgres.py | 17 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite.py | 2 |
2 files changed, 10 insertions, 9 deletions
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index b7c4eda338..ba19785fd7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -32,14 +32,15 @@ class PostgresEngine(object): self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet - def check_database(self, txn): - txn.execute("SHOW SERVER_ENCODING") - rows = txn.fetchall() - if rows and rows[0][0] != "UTF8": - raise IncorrectDatabaseSetup( - "Database has incorrect encoding: '%s' instead of 'UTF8'\n" - "See docs/postgres.rst for more information." % (rows[0][0],) - ) + def check_database(self, db_conn): + with db_conn.cursor() as txn: + txn.execute("SHOW SERVER_ENCODING") + rows = txn.fetchall() + if rows and rows[0][0] != "UTF8": + raise IncorrectDatabaseSetup( + "Database has incorrect encoding: '%s' instead of 'UTF8'\n" + "See docs/postgres.rst for more information." % (rows[0][0],) + ) def convert_param_style(self, sql): return sql.replace("?", "%s") diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index df039a072d..3b3c13360b 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -53,7 +53,7 @@ class Sqlite3Engine(object): """ return False - def check_database(self, txn): + def check_database(self, db_conn): pass def convert_param_style(self, sql): |