summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py44
1 files changed, 37 insertions, 7 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b70ca3087b..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -179,6 +180,9 @@ class LoggingDatabaseConnection:
 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
 
 
+R = TypeVar("R")
+
+
 class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -258,13 +262,32 @@ class LoggingTransaction:
         return self.txn.description
 
     def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+        """Similar to `executemany`, except `txn.rowcount` will not be correct
+        afterwards.
+
+        More efficient than `executemany` on PostgreSQL
+        """
+
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
-            for val in args:
-                self.execute(sql, val)
+            self.executemany(sql, args)
+
+    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+        """Corresponds to psycopg2.extras.execute_values. Only available when
+        using postgres.
+
+        Always sets fetch=True when caling `execute_values`, so will return the
+        results.
+        """
+        assert isinstance(self.database_engine, PostgresEngine)
+        from psycopg2.extras import execute_values  # type: ignore
+
+        return self._do_execute(
+            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+        )
 
     def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
@@ -276,7 +299,7 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql: str, *args: Any) -> None:
+    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -347,9 +370,6 @@ class PerformanceCounters:
         return top_n_counters
 
 
-R = TypeVar("R")
-
-
 class DatabasePool:
     """Wraps a single physical database and connection pool.
 
@@ -398,6 +418,16 @@ class DatabasePool:
                 self._check_safe_to_upsert,
             )
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            engine, get_chain_id_txn, "event_auth_chain_id"
+        )
+
     def is_running(self) -> bool:
         """Is the database pool currently running
         """
@@ -863,7 +893,7 @@ class DatabasePool:
             ", ".join("?" for _ in keys[0]),
         )
 
-        txn.executemany(sql, vals)
+        txn.execute_batch(sql, vals)
 
     async def simple_upsert(
         self,