summary refs log tree commit diff
path: root/synapse/storage/engines/postgres.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/engines/postgres.py')
-rw-r--r--synapse/storage/engines/postgres.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index c84cb452b0..a077345960 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 from ._base import IncorrectDatabaseSetup
 
+logger = logging.getLogger(__name__)
+
 
 class PostgresEngine(object):
     single_threaded = False
@@ -52,6 +56,44 @@ class PostgresEngine(object):
                     "See docs/postgres.rst for more information." % (rows[0][0],)
                 )
 
+            txn.execute(
+                "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+            )
+            collation, ctype = txn.fetchone()
+            if collation != "C":
+                logger.warning(
+                    "Database has incorrect collation of %r. Should be 'C'", collation
+                )
+
+            if ctype != "C":
+                logger.warning(
+                    "Database has incorrect ctype of %r. Should be 'C'", ctype
+                )
+
+    def check_new_database(self, txn):
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing database.
+        """
+
+        txn.execute(
+            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+        )
+        collation, ctype = txn.fetchone()
+
+        errors = []
+
+        if collation != "C":
+            errors.append("    - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
+
+        if ctype != "C":
+            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+
+        if errors:
+            raise IncorrectDatabaseSetup(
+                "Database is incorrectly configured:\n\n%s\n\n"
+                "See docs/postgres.md for more information." % ("\n".join(errors))
+            )
+
     def convert_param_style(self, sql):
         return sql.replace("?", "%s")