diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a19d65ad23..4646926449 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -85,8 +85,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
- """Get the connection pool for the database.
- """
+ """Get the connection pool for the database."""
# By default enable `cp_reconnect`. We need to fiddle with db_args in case
# someone has explicitly set `cp_reconnect`.
@@ -158,8 +157,8 @@ class LoggingDatabaseConnection:
def commit(self) -> None:
self.conn.commit()
- def rollback(self, *args, **kwargs) -> None:
- self.conn.rollback(*args, **kwargs)
+ def rollback(self) -> None:
+ self.conn.rollback()
def __enter__(self) -> "Connection":
self.conn.__enter__()
@@ -244,12 +243,15 @@ class LoggingTransaction:
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
+ def fetchone(self) -> Optional[Tuple]:
+ return self.txn.fetchone()
+
+ def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+ return self.txn.fetchmany(size=size)
+
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
- def fetchone(self) -> Tuple:
- return self.txn.fetchone()
-
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@@ -262,13 +264,18 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+ """Similar to `executemany`, except `txn.rowcount` will not be correct
+ afterwards.
+
+ More efficient than `executemany` on PostgreSQL
+ """
+
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
- for val in args:
- self.execute(sql, val)
+ self.executemany(sql, args)
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
@@ -424,8 +431,7 @@ class DatabasePool:
)
def is_running(self) -> bool:
- """Is the database pool currently running
- """
+ """Is the database pool currently running"""
return self._db_pool.running
async def _check_safe_to_upsert(self) -> None:
@@ -538,7 +544,11 @@ class DatabasePool:
# This can happen if the database disappears mid
# transaction.
transaction_logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name,
+ e,
+ i,
+ N,
)
if i < N:
i += 1
@@ -559,7 +569,9 @@ class DatabasePool:
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning(
- "[TXN EROLL] {%s} %s", name, e1,
+ "[TXN EROLL] {%s} %s",
+ name,
+ e1,
)
continue
raise
@@ -749,6 +761,7 @@ class DatabasePool:
Returns:
A list of dicts where the key is the column header.
"""
+ assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@@ -888,7 +901,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]),
)
- txn.executemany(sql, vals)
+ txn.execute_batch(sql, vals)
async def simple_upsert(
self,
@@ -1397,7 +1410,10 @@ class DatabasePool:
@staticmethod
def simple_select_onecol_txn(
- txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
@@ -1707,7 +1723,11 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
- desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
+ desc,
+ self.simple_delete_one_txn,
+ table,
+ keyvalues,
+ db_autocommit=True,
)
@staticmethod
|