summary refs log tree commit diff
path: root/synapse/storage/engines/postgres.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/engines/postgres.py')
-rw-r--r--synapse/storage/engines/postgres.py45
1 files changed, 30 insertions, 15 deletions
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 808342fafb..e8d29e2870 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
         self.default_isolation_level = (
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
+        self.config = database_config
 
     @property
     def single_threaded(self) -> bool:
         return False
 
+    def get_db_locale(self, txn):
+        txn.execute(
+            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+        )
+        collation, ctype = txn.fetchone()
+        return collation, ctype
+
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
         self._version = db_conn.server_version
+        allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
 
         # Are we on a supported PostgreSQL version?
         if not allow_outdated_version and self._version < 100000:
@@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
                     "See docs/postgres.md for more information." % (rows[0][0],)
                 )
 
-            txn.execute(
-                "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-            )
-            collation, ctype = txn.fetchone()
+            collation, ctype = self.get_db_locale(txn)
             if collation != "C":
                 logger.warning(
-                    "Database has incorrect collation of %r. Should be 'C'\n"
-                    "See docs/postgres.md for more information.",
+                    "Database has incorrect collation of %r. Should be 'C'",
                     collation,
                 )
+                if not allow_unsafe_locale:
+                    raise IncorrectDatabaseSetup(
+                        "Database has incorrect collation of %r. Should be 'C'\n"
+                        "See docs/postgres.md for more information. You can override this check by"
+                        "setting 'allow_unsafe_locale' to true in the database config.",
+                        collation,
+                    )
 
             if ctype != "C":
-                logger.warning(
-                    "Database has incorrect ctype of %r. Should be 'C'\n"
-                    "See docs/postgres.md for more information.",
-                    ctype,
-                )
+                if not allow_unsafe_locale:
+                    logger.warning(
+                        "Database has incorrect ctype of %r. Should be 'C'",
+                        ctype,
+                    )
+                    raise IncorrectDatabaseSetup(
+                        "Database has incorrect ctype of %r. Should be 'C'\n"
+                        "See docs/postgres.md for more information. You can override this check by"
+                        "setting 'allow_unsafe_locale' to true in the database config.",
+                        ctype,
+                    )
 
     def check_new_database(self, txn):
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
 
-        txn.execute(
-            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-        )
-        collation, ctype = txn.fetchone()
+        collation, ctype = self.get_db_locale(txn)
 
         errors = []