summary refs log tree commit diff
path: root/synapse/storage/engines/postgres.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/engines/postgres.py')
-rw-r--r--synapse/storage/engines/postgres.py108
1 files changed, 88 insertions, 20 deletions
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 289b6bc281..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,32 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import IncorrectDatabaseSetup
+import logging
 
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 
-class PostgresEngine(object):
-    single_threaded = False
+logger = logging.getLogger(__name__)
 
+
+class PostgresEngine(BaseDatabaseEngine):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
         self.module.extensions.register_type(self.module.extensions.UNICODE)
-        self.synchronous_commit = database_config.get("synchronous_commit", True)
-        self._version = None  # unknown as yet
 
-    def check_database(self, txn):
-        txn.execute("SHOW SERVER_ENCODING")
-        rows = txn.fetchall()
-        if rows and rows[0][0] != "UTF8":
-            raise IncorrectDatabaseSetup(
-                "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
-                "See docs/postgres.rst for more information." % (rows[0][0],)
-            )
+        # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
+        # actually want to use bytes than wrap it in `bytearray`.
+        def _disable_bytes_adapter(_):
+            raise Exception("Passing bytes to DB is disabled.")
 
-    def convert_param_style(self, sql):
-        return sql.replace("?", "%s")
+        self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
+        self.synchronous_commit = database_config.get("synchronous_commit", True)
+        self._version = None  # unknown as yet
 
-    def on_new_connection(self, db_conn):
+    @property
+    def single_threaded(self) -> bool:
+        return False
 
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
@@ -46,9 +46,64 @@ class PostgresEngine(object):
         self._version = db_conn.server_version
 
         # Are we on a supported PostgreSQL version?
-        if self._version < 90500:
+        if not allow_outdated_version and self._version < 90500:
             raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
 
+        with db_conn.cursor() as txn:
+            txn.execute("SHOW SERVER_ENCODING")
+            rows = txn.fetchall()
+            if rows and rows[0][0] != "UTF8":
+                raise IncorrectDatabaseSetup(
+                    "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+                    "See docs/postgres.md for more information." % (rows[0][0],)
+                )
+
+            txn.execute(
+                "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+            )
+            collation, ctype = txn.fetchone()
+            if collation != "C":
+                logger.warning(
+                    "Database has incorrect collation of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    collation,
+                )
+
+            if ctype != "C":
+                logger.warning(
+                    "Database has incorrect ctype of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    ctype,
+                )
+
+    def check_new_database(self, txn):
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing database.
+        """
+
+        txn.execute(
+            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+        )
+        collation, ctype = txn.fetchone()
+
+        errors = []
+
+        if collation != "C":
+            errors.append("    - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
+
+        if ctype != "C":
+            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+
+        if errors:
+            raise IncorrectDatabaseSetup(
+                "Database is incorrectly configured:\n\n%s\n\n"
+                "See docs/postgres.md for more information." % ("\n".join(errors))
+            )
+
+    def convert_param_style(self, sql):
+        return sql.replace("?", "%s")
+
+    def on_new_connection(self, db_conn):
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
@@ -72,6 +127,19 @@ class PostgresEngine(object):
         """
         return True
 
+    @property
+    def supports_tuple_comparison(self):
+        """
+        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+        """
+        return True
+
+    @property
+    def supports_using_any_list(self):
+        """Do we support using `a = ANY(?)` and passing a list
+        """
+        return True
+
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
@@ -99,8 +167,8 @@ class PostgresEngine(object):
         Returns:
             string
         """
-        # note that this is a bit of a hack because it relies on on_new_connection
-        # having been called at least once. Still, that should be a safe bet here.
+        # note that this is a bit of a hack because it relies on check_database
+        # having been called. Still, that should be a safe bet here.
         numver = self._version
         assert numver is not None