diff options
author | Shay <hillerys@element.io> | 2022-03-23 10:23:05 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-23 10:23:05 -0700 |
commit | e78d4f61fc881851ab35e9a889239a61cf9805e5 (patch) | |
tree | 3c579083dc15936dac5dec0367d8b21274698b53 /synapse/storage | |
parent | Use psycopg2 type stubs (#12269) (diff) | |
download | synapse-e78d4f61fc881851ab35e9a889239a61cf9805e5.tar.xz |
Refuse to start if DB has an unsafe locale (#12262)
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/engines/postgres.py | 45 |
1 files changed, 30 insertions, 15 deletions
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 808342fafb..e8d29e2870 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine): self.default_isolation_level = ( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) + self.config = database_config @property def single_threaded(self) -> bool: return False + def get_db_locale(self, txn): + txn.execute( + "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" + ) + collation, ctype = txn.fetchone() + return collation, ctype + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them # together. For example, version 8.1.5 will be returned as 80105 self._version = db_conn.server_version + allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? if not allow_outdated_version and self._version < 100000: @@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine): "See docs/postgres.md for more information." % (rows[0][0],) ) - txn.execute( - "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" - ) - collation, ctype = txn.fetchone() + collation, ctype = self.get_db_locale(txn) if collation != "C": logger.warning( - "Database has incorrect collation of %r. Should be 'C'\n" - "See docs/postgres.md for more information.", + "Database has incorrect collation of %r. Should be 'C'", collation, ) + if not allow_unsafe_locale: + raise IncorrectDatabaseSetup( + "Database has incorrect collation of %r. Should be 'C'\n" + "See docs/postgres.md for more information. You can override this check by" + "setting 'allow_unsafe_locale' to true in the database config.", + collation, + ) if ctype != "C": - logger.warning( - "Database has incorrect ctype of %r. Should be 'C'\n" - "See docs/postgres.md for more information.", - ctype, - ) + if not allow_unsafe_locale: + logger.warning( + "Database has incorrect ctype of %r. Should be 'C'", + ctype, + ) + raise IncorrectDatabaseSetup( + "Database has incorrect ctype of %r. Should be 'C'\n" + "See docs/postgres.md for more information. You can override this check by" + "setting 'allow_unsafe_locale' to true in the database config.", + ctype, + ) def check_new_database(self, txn): """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ - txn.execute( - "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" - ) - collation, ctype = txn.fetchone() + collation, ctype = self.get_db_locale(txn) errors = [] |