summary refs log tree commit diff
path: root/synapse/storage/engines/postgres.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/engines/postgres.py')
-rw-r--r--synapse/storage/engines/postgres.py34
1 files changed, 17 insertions, 17 deletions
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index b7c4eda338..c84cb452b0 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -32,20 +32,7 @@ class PostgresEngine(object):
         self.synchronous_commit = database_config.get("synchronous_commit", True)
         self._version = None  # unknown as yet
 
-    def check_database(self, txn):
-        txn.execute("SHOW SERVER_ENCODING")
-        rows = txn.fetchall()
-        if rows and rows[0][0] != "UTF8":
-            raise IncorrectDatabaseSetup(
-                "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
-                "See docs/postgres.rst for more information." % (rows[0][0],)
-            )
-
-    def convert_param_style(self, sql):
-        return sql.replace("?", "%s")
-
-    def on_new_connection(self, db_conn):
-
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
@@ -53,9 +40,22 @@ class PostgresEngine(object):
         self._version = db_conn.server_version
 
         # Are we on a supported PostgreSQL version?
-        if self._version < 90500:
+        if not allow_outdated_version and self._version < 90500:
             raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
 
+        with db_conn.cursor() as txn:
+            txn.execute("SHOW SERVER_ENCODING")
+            rows = txn.fetchall()
+            if rows and rows[0][0] != "UTF8":
+                raise IncorrectDatabaseSetup(
+                    "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+                    "See docs/postgres.rst for more information." % (rows[0][0],)
+                )
+
+    def convert_param_style(self, sql):
+        return sql.replace("?", "%s")
+
+    def on_new_connection(self, db_conn):
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
@@ -119,8 +119,8 @@ class PostgresEngine(object):
         Returns:
             string
         """
-        # note that this is a bit of a hack because it relies on on_new_connection
-        # having been called at least once. Still, that should be a safe bet here.
+        # note that this is a bit of a hack because it relies on check_database
+        # having been called. Still, that should be a safe bet here.
         numver = self._version
         assert numver is not None