diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d1b5760c2c..a19d65ad23 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -42,7 +42,6 @@ from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import (
LoggingContext,
- LoggingContextOrSentinel,
current_context,
make_deferred_yieldable,
)
@@ -50,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection
# python 3 does not have a maximum int value
@@ -180,6 +180,9 @@ class LoggingDatabaseConnection:
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+R = TypeVar("R")
+
+
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -267,6 +270,20 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
+ def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+ """Corresponds to psycopg2.extras.execute_values. Only available when
+ using postgres.
+
+ Always sets fetch=True when caling `execute_values`, so will return the
+ results.
+ """
+ assert isinstance(self.database_engine, PostgresEngine)
+ from psycopg2.extras import execute_values # type: ignore
+
+ return self._do_execute(
+ lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+ )
+
def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
@@ -277,7 +294,7 @@ class LoggingTransaction:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql: str, *args: Any) -> None:
+ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -348,9 +365,6 @@ class PerformanceCounters:
return top_n_counters
-R = TypeVar("R")
-
-
class DatabasePool:
"""Wraps a single physical database and connection pool.
@@ -399,6 +413,16 @@ class DatabasePool:
self._check_safe_to_upsert,
)
+ # We define this sequence here so that it can be referenced from both
+ # the DataStore and PersistEventStore.
+ def get_chain_id_txn(txn):
+ txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+ return txn.fetchone()[0]
+
+ self.event_chain_id_gen = build_sequence_generator(
+ engine, get_chain_id_txn, "event_auth_chain_id"
+ )
+
def is_running(self) -> bool:
"""Is the database pool currently running
"""
@@ -671,12 +695,15 @@ class DatabasePool:
Returns:
The result of func
"""
- parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
- if not parent_context:
+ curr_context = current_context()
+ if not curr_context:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
parent_context = None
+ else:
+ assert isinstance(curr_context, LoggingContext)
+ parent_context = curr_context
start_time = monotonic_time()
|