diff options
-rw-r--r-- | synapse/storage/database.py | 12 | ||||
-rw-r--r-- | synapse/storage/engines/__init__.py | 8 | ||||
-rw-r--r-- | synapse/storage/engines/psycopg.py | 57 | ||||
-rw-r--r-- | synapse/storage/schema/main/delta/69/01as_txn_seq.py | 17 | ||||
-rw-r--r-- | synapse/storage/types.py | 19 | ||||
-rw-r--r-- | tests/utils.py | 2 |
6 files changed, 58 insertions, 57 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 55bcb90001..d78898370b 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -56,7 +56,7 @@ from synapse.logging.context import ( from synapse.metrics import register_threadpool from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, PsycopgEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter @@ -334,7 +334,8 @@ class LoggingTransaction: def fetchone(self) -> Optional[Tuple]: return self.txn.fetchone() - def fetchmany(self, size: Optional[int] = None) -> List[Tuple]: + def fetchmany(self, size: int = 0) -> List[Tuple]: + # XXX This can also be called with no arguments. return self.txn.fetchmany(size=size) def fetchall(self) -> List[Tuple]: @@ -400,6 +401,11 @@ class LoggingTransaction: def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" + if isinstance(self.database_engine, PsycopgEngine): + import psycopg.sql + if isinstance(sql, psycopg.sql.Composed): + return sql.as_string(None) + return " ".join(line.strip() for line in sql.splitlines() if line.strip()) def _do_execute( @@ -440,7 +446,7 @@ class LoggingTransaction: finally: secs = time.time() - start sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) - sql_query_timer.labels(sql.split()[0]).observe(secs) + sql_query_timer.labels(one_line_sql.split()[0]).observe(secs) def close(self) -> None: self.txn.close() diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index a9fdcdcef7..8038dce595 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -21,7 +21,7 @@ from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup # installed. To account for this, create dummy classes on import failure so we can # still run `isinstance()` checks. def dummy_engine(name: str, module: str) -> BaseDatabaseEngine: - class Engine(BaseDatabaseEngine): # type: ignore[no-redef] + class Engine(BaseDatabaseEngine): def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc] raise RuntimeError( f"Cannot create {name}Engine -- {module} module is not installed" @@ -33,17 +33,17 @@ def dummy_engine(name: str, module: str) -> BaseDatabaseEngine: try: from .postgres import PostgresEngine except ImportError: - PostgresEngine = dummy_engine("PostgresEngine", "psycopg2") + PostgresEngine = dummy_engine("PostgresEngine", "psycopg2") # type: ignore[misc,assignment] try: from .psycopg import PsycopgEngine except ImportError: - PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg") + PsycopgEngine = dummy_engine("PsycopgEngine", "psycopg") # type: ignore[misc,assignment] try: from .sqlite import Sqlite3Engine except ImportError: - Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3") + Sqlite3Engine = dummy_engine("Sqlite3Engine", "sqlite3") # type: ignore[misc,assignment] def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine: diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py index bf3cf94777..9dff9adbc1 100644 --- a/synapse/storage/engines/psycopg.py +++ b/synapse/storage/engines/psycopg.py @@ -15,7 +15,9 @@ import logging from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast -import psycopg2.extensions +import psycopg +import psycopg.errors +import psycopg.sql from synapse.storage.engines._base import ( BaseDatabaseEngine, @@ -31,28 +33,26 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): +class PsycopgEngine(BaseDatabaseEngine[psycopg.Connection]): def __init__(self, database_config: Mapping[str, Any]): - super().__init__(psycopg2, database_config) - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) + super().__init__(psycopg, database_config) + # psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do # actually want to use bytes than wrap it in `bytearray`. - def _disable_bytes_adapter(_: bytes) -> NoReturn: - raise Exception("Passing bytes to DB is disabled.") + # def _disable_bytes_adapter(_: bytes) -> NoReturn: + # raise Exception("Passing bytes to DB is disabled.") - psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter) + # psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter) self.synchronous_commit: bool = database_config.get("synchronous_commit", True) self._version: Optional[int] = None # unknown as yet - self.isolation_level_map: Mapping[int, int] = { - IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED, - IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ, - IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE, + self.isolation_level_map: Mapping[int, psycopg.IsolationLevel] = { + IsolationLevel.READ_COMMITTED: psycopg.IsolationLevel.READ_COMMITTED, + IsolationLevel.REPEATABLE_READ: psycopg.IsolationLevel.REPEATABLE_READ, + IsolationLevel.SERIALIZABLE: psycopg.IsolationLevel.SERIALIZABLE, } - self.default_isolation_level = ( - psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ - ) + self.default_isolation_level = psycopg.IsolationLevel.REPEATABLE_READ self.config = database_config @property @@ -68,14 +68,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): def check_database( self, - db_conn: psycopg2.extensions.connection, + db_conn: psycopg.Connection, allow_outdated_version: bool = False, ) -> None: # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them # together. For example, version 8.1.5 will be returned as 80105 - self._version = cast(int, db_conn.server_version) + self._version = cast(int, db_conn.info.server_version) allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? @@ -140,6 +140,9 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): ) def convert_param_style(self, sql: str) -> str: + if isinstance(sql, psycopg.sql.Composed): + return sql + return sql.replace("?", "%s") def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: @@ -186,14 +189,14 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): return True def is_deadlock(self, error: Exception) -> bool: - if isinstance(error, psycopg2.DatabaseError): + if isinstance(error, psycopg.errors.Error): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html # "40001" serialization_failure # "40P01" deadlock_detected - return error.pgcode in ["40001", "40P01"] + return error.sqlstate in ["40001", "40P01"] return False - def is_connection_closed(self, conn: psycopg2.extensions.connection) -> bool: + def is_connection_closed(self, conn: psycopg.Connection) -> bool: return bool(conn.closed) def lock_table(self, txn: Cursor, table: str) -> None: @@ -213,19 +216,19 @@ class PsycopgEngine(BaseDatabaseEngine[psycopg2.extensions.connection]): else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) - def in_transaction(self, conn: psycopg2.extensions.connection) -> bool: - return conn.status != psycopg2.extensions.STATUS_READY + def in_transaction(self, conn: psycopg.Connection) -> bool: + return conn.info.transaction_status != psycopg.pq.TransactionStatus.IDLE def attempt_to_set_autocommit( - self, conn: psycopg2.extensions.connection, autocommit: bool + self, conn: psycopg.Connection, autocommit: bool ) -> None: - return conn.set_session(autocommit=autocommit) + conn.autocommit = autocommit def attempt_to_set_isolation_level( - self, conn: psycopg2.extensions.connection, isolation_level: Optional[int] + self, conn: psycopg.Connection, isolation_level: Optional[int] ) -> None: if isolation_level is None: - isolation_level = self.default_isolation_level + pg_isolation_level = self.default_isolation_level else: - isolation_level = self.isolation_level_map[isolation_level] - return conn.set_isolation_level(isolation_level) + pg_isolation_level = self.isolation_level_map[isolation_level] + conn.isolation_level = pg_isolation_level diff --git a/synapse/storage/schema/main/delta/69/01as_txn_seq.py b/synapse/storage/schema/main/delta/69/01as_txn_seq.py index 4856569ceb..25bb0b8395 100644 --- a/synapse/storage/schema/main/delta/69/01as_txn_seq.py +++ b/synapse/storage/schema/main/delta/69/01as_txn_seq.py @@ -17,7 +17,7 @@ Adds a postgres SEQUENCE for generating application service transaction IDs. """ -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import PsycopgEngine def run_create(cur, database_engine, *args, **kwargs): @@ -38,7 +38,14 @@ def run_create(cur, database_engine, *args, **kwargs): start_val = max(last_txn_max, txn_max) + 1 - cur.execute( - "CREATE SEQUENCE application_services_txn_id_seq START WITH %s", - (start_val,), - ) + # XXX This is a hack. + sql = f"CREATE SEQUENCE application_services_txn_id_seq START WITH {start_val}" + args = () + + if isinstance(database_engine, PsycopgEngine): + import psycopg.sql + cur.execute( + psycopg.sql.SQL(sql).format(args) + ) + else: + cur.execute(sql, args) diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 0031df1e06..bf6b5a437b 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -33,7 +33,7 @@ class Cursor(Protocol): def fetchone(self) -> Optional[Tuple]: ... - def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: + def fetchmany(self, size: int = ...) -> List[Tuple]: ... def fetchall(self) -> List[Tuple]: @@ -42,22 +42,7 @@ class Cursor(Protocol): @property def description( self, - ) -> Optional[ - Sequence[ - # Note that this is an approximate typing based on sqlite3 and other - # drivers, and may not be entirely accurate. - # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description - Tuple[ - str, - Optional[Any], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - ] - ] - ]: + ) -> Optional[Sequence[Any]]: ... @property diff --git a/tests/utils.py b/tests/utils.py index 77e4b18011..5a9394c411 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -83,11 +83,11 @@ def setupdb() -> None: # Set up in the db db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, port=POSTGRES_PORT, password=POSTGRES_PASSWORD, + dbname=POSTGRES_BASE_DB, ) logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") prepare_database(logging_conn, db_engine, None) |