summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-01-11 16:09:22 +0000
committerGitHub <noreply@github.com>2021-01-11 16:09:22 +0000
commit1315a2e8be702a513d49c1142e9e52b642286635 (patch)
tree2c9aca9e27a2fd4ac1dda844015cefb26a021939 /synapse/storage/database.py
parentClean up exception handling in the startup code (#9059) (diff)
downloadsynapse-1315a2e8be702a513d49c1142e9e52b642286635.tar.xz
Use a chain cover index to efficiently calculate auth chain difference (#8868)
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b70ca3087b..6cfadc2b4e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -179,6 +179,9 @@ class LoggingDatabaseConnection:
 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
 
 
+R = TypeVar("R")
+
+
 class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -266,6 +269,20 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
+    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+        """Corresponds to psycopg2.extras.execute_values. Only available when
+        using postgres.
+
+        Always sets fetch=True when caling `execute_values`, so will return the
+        results.
+        """
+        assert isinstance(self.database_engine, PostgresEngine)
+        from psycopg2.extras import execute_values  # type: ignore
+
+        return self._do_execute(
+            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+        )
+
     def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
 
@@ -276,7 +293,7 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql: str, *args: Any) -> None:
+    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -347,9 +364,6 @@ class PerformanceCounters:
         return top_n_counters
 
 
-R = TypeVar("R")
-
-
 class DatabasePool:
     """Wraps a single physical database and connection pool.