summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/database.py12
-rw-r--r--synapse/storage/engines/__init__.py8
-rw-r--r--synapse/storage/engines/psycopg.py57
-rw-r--r--synapse/storage/schema/main/delta/69/01as_txn_seq.py17
-rw-r--r--synapse/storage/types.py19
-rw-r--r--tests/utils.py2
6 files changed, 58 insertions, 57 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 55bcb90001..d78898370b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -56,7 +56,7 @@ from synapse.logging.context import (
 from synapse.metrics import register_threadpool
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, PsycopgEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
 from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
@@ -334,7 +334,8 @@ class LoggingTransaction:
     def fetchone(self) -> Optional[Tuple]:
         return self.txn.fetchone()
 
-    def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+    def fetchmany(self, size: int = 0) -> List[Tuple]:
+        # XXX This can also be called with no arguments.
         return self.txn.fetchmany(size=size)
 
     def fetchall(self) -> List[Tuple]:
@@ -400,6 +401,11 @@ class LoggingTransaction:
 
     def _make_sql_one_line(self, sql: str) -> str:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
+        if isinstance(self.database_engine, PsycopgEngine):
+            import psycopg.sql
+            if isinstance(sql, psycopg.sql.Composed):
+                return sql.as_string(None)
+
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
     def _do_execute(
@@ -440,7 +446,7 @@ class LoggingTransaction:
         finally:
             secs = time.time() - start
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
-            sql_query_timer.labels(sql.split()[0]).observe(secs)
+            sql_query_timer.labels(one_line_sql.split()[0]).observe(secs)
 
     def close(self) -> None:
         self.txn.close()
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index a9fdcdcef7..8038dce595 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -21,7 +21,7 @@ from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 # installed. To account for this, create dummy classes on import failure so we can
 # still run `isinstance()` checks.
 def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
-    class Engine(BaseDatabaseEngine):  # type: ignore[no-redef]
+    class Engine(BaseDatabaseEngine):
         def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
             raise RuntimeError(
                 f"Cannot create {name}Engine -- {module} module is not installed"
@@ -33,17 +33,17 @@ def dummy_engine(name: str, module: str) -> BaseDatabaseEngine:
 try:
     from .postgres import PostgresEngine
 except ImportError:
-    PostgresEngine = dummy_engine("PostgresEngine", "psycopg2")
+    PostgresEngine = dummy_engine("PostgresEngine", "psycopg2")  # type: ignore[misc,assignment]
 
 try:
     from .psycopg import PsycopgEngine
 except ImportError:
-    PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg")
+    PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg")  # type: ignore[misc,assignment]
 
 try:
     from .sqlite import Sqlite3Engine
 except ImportError:
-    Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3")
+    Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3")  # type: ignore[misc,assignment]
 
 
 def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py
index bf3cf94777..9dff9adbc1 100644
--- a/synapse/storage/engines/psycopg.py
+++ b/synapse/storage/engines/psycopg.py
@@ -15,7 +15,9 @@
 import logging
 from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
 
-import psycopg2.extensions
+import psycopg
+import psycopg.errors
+import psycopg.sql
 
 from synapse.storage.engines._base import (
     BaseDatabaseEngine,
@@ -31,28 +33,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
+class PsycopgEngine(BaseDatabaseEngine[psycopg.Connection]):
     def __init__(self, database_config: Mapping[str, Any]):
-        super().__init__(psycopg2, database_config)
-        psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+        super().__init__(psycopg, database_config)
+        # psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
         # actually want to use bytes than wrap it in `bytearray`.
-        def _disable_bytes_adapter(_: bytes) -> NoReturn:
-            raise Exception("Passing bytes to DB is disabled.")
+        # def _disable_bytes_adapter(_: bytes) -> NoReturn:
+        #     raise Exception("Passing bytes to DB is disabled.")
 
-        psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
+        # psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
         self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
         self._version: Optional[int] = None  # unknown as yet
 
-        self.isolation_level_map: Mapping[int, int] = {
-            IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
-            IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
-            IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
+        self.isolation_level_map: Mapping[int, psycopg.IsolationLevel] = {
+            IsolationLevel.READ_COMMITTED: psycopg.IsolationLevel.READ_COMMITTED,
+            IsolationLevel.REPEATABLE_READ: psycopg.IsolationLevel.REPEATABLE_READ,
+            IsolationLevel.SERIALIZABLE: psycopg.IsolationLevel.SERIALIZABLE,
         }
-        self.default_isolation_level = (
-            psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
-        )
+        self.default_isolation_level = psycopg.IsolationLevel.REPEATABLE_READ
         self.config = database_config
 
     @property
@@ -68,14 +68,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
 
     def check_database(
         self,
-        db_conn: psycopg2.extensions.connection,
+        db_conn: psycopg.Connection,
         allow_outdated_version: bool = False,
     ) -> None:
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
-        self._version = cast(int, db_conn.server_version)
+        self._version = cast(int, db_conn.info.server_version)
         allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
 
         # Are we on a supported PostgreSQL version?
@@ -140,6 +140,9 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
             )
 
     def convert_param_style(self, sql: str) -> str:
+        if isinstance(sql, psycopg.sql.Composed):
+            return sql
+
         return sql.replace("?", "%s")
 
     def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
@@ -186,14 +189,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
         return True
 
     def is_deadlock(self, error: Exception) -> bool:
-        if isinstance(error, psycopg2.DatabaseError):
+        if isinstance(error, psycopg.errors.Error):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
             # "40001" serialization_failure
             # "40P01" deadlock_detected
-            return error.pgcode in ["40001", "40P01"]
+            return error.sqlstate in ["40001", "40P01"]
         return False
 
-    def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool:
+    def is_connection_closed(self, conn: psycopg.Connection) -> bool:
         return bool(conn.closed)
 
     def lock_table(self, txn: Cursor, table: str) -> None:
@@ -213,19 +216,19 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]):
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
 
-    def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
-        return conn.status != psycopg2.extensions.STATUS_READY
+    def in_transaction(self, conn: psycopg.Connection) -> bool:
+        return conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE
 
     def attempt_to_set_autocommit(
-        self, conn: psycopg2.extensions.connection, autocommit: bool
+        self, conn: psycopg.Connection, autocommit: bool
     ) -> None:
-        return conn.set_session(autocommit=autocommit)
+        conn.autocommit = autocommit
 
     def attempt_to_set_isolation_level(
-        self, conn: psycopg2.extensions.connection, isolation_level: Optional[int]
+        self, conn: psycopg.Connection, isolation_level: Optional[int]
     ) -> None:
         if isolation_level is None:
-            isolation_level = self.default_isolation_level
+            pg_isolation_level = self.default_isolation_level
         else:
-            isolation_level = self.isolation_level_map[isolation_level]
-        return conn.set_isolation_level(isolation_level)
+            pg_isolation_level = self.isolation_level_map[isolation_level]
+        conn.isolation_level = pg_isolation_level
diff --git a/synapse/storage/schema/main/delta/69/01as_txn_seq.py b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
index 4856569ceb..25bb0b8395 100644
--- a/synapse/storage/schema/main/delta/69/01as_txn_seq.py
+++ b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
@@ -17,7 +17,7 @@
 Adds a postgres SEQUENCE for generating application service transaction IDs.
 """
 
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import PsycopgEngine
 
 
 def run_create(cur, database_engine, *args, **kwargs):
@@ -38,7 +38,14 @@ def run_create(cur, database_engine, *args, **kwargs):
 
         start_val = max(last_txn_max, txn_max) + 1
 
-        cur.execute(
-            "CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
-            (start_val,),
-        )
+        # XXX This is a hack.
+        sql = f"CREATE SEQUENCE application_services_txn_id_seq START WITH {start_val}"
+        args = ()
+
+        if isinstance(database_engine, PsycopgEngine):
+            import psycopg.sql
+            cur.execute(
+                psycopg.sql.SQL(sql).format(args)
+            )
+        else:
+            cur.execute(sql, args)
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 0031df1e06..bf6b5a437b 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -33,7 +33,7 @@ class Cursor(Protocol):
     def fetchone(self) -> Optional[Tuple]:
         ...
 
-    def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
+    def fetchmany(self, size: int = ...) -> List[Tuple]:
         ...
 
     def fetchall(self) -> List[Tuple]:
@@ -42,22 +42,7 @@ class Cursor(Protocol):
     @property
     def description(
         self,
-    ) -> Optional[
-        Sequence[
-            # Note that this is an approximate typing based on sqlite3 and other
-            # drivers, and may not be entirely accurate.
-            # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
-            Tuple[
-                str,
-                Optional[Any],
-                Optional[int],
-                Optional[int],
-                Optional[int],
-                Optional[int],
-                Optional[int],
-            ]
-        ]
-    ]:
+    ) -> Optional[Sequence[Any]]:
         ...
 
     @property
diff --git a/tests/utils.py b/tests/utils.py
index 77e4b18011..5a9394c411 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -83,11 +83,11 @@ def setupdb() -> None:
 
         # Set up in the db
         db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
             host=POSTGRES_HOST,
             port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
+            dbname=POSTGRES_BASE_DB,
         )
         logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
         prepare_database(logging_conn, db_engine, None)