summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2023-09-29 13:57:50 -0400
committerPatrick Cloke <patrickc@matrix.org>2023-09-29 14:07:34 -0400
commita072285e9dedd0342799aef446844af8df5f3685 (patch)
treea6b0a516256f20ada7fb3293d68da5cf09a5e68b
parentLint. (diff)
downloadsynapse-a072285e9dedd0342799aef446844af8df5f3685.tar.xz
Use _do_execute for COPY TO/FROM.
-rw-r--r--synapse/storage/database.py35
1 files changed, 25 insertions, 10 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 45962d07cc..aed1a1742e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -426,19 +426,34 @@ class LoggingTransaction:
             values,
         )
 
-    def copy_write(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
-        # TODO use _do_execute
-        with self.txn.copy(sql) as copy:
-            for record in args:
-                copy.write_row(record)
+    def copy_write(
+        self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
+    ) -> None:
+        """Corresponds to a PostgreSQL COPY (...) FROM STDIN call."""
+        assert isinstance(self.database_engine, PsycopgEngine)
+
+        def f(
+            the_sql: str, the_args: Iterable[Any], the_values: Iterable[Iterable[Any]]
+        ) -> None:
+            with self.txn.copy(the_sql, the_args) as copy:
+                for record in the_values:
+                    copy.write_row(record)
+
+        self._do_execute(f, sql, args, values)
 
     def copy_read(
         self, sql: str, args: Iterable[Iterable[Any]]
     ) -> Iterable[Tuple[Any, ...]]:
-        # TODO use _do_execute
-        sql = self.database_engine.convert_param_style(sql)
-        with self.txn.copy(sql, args) as copy:
-            yield from copy.rows()
+        """Corresponds to a PostgreSQL COPY (...) TO STDOUT call."""
+        assert isinstance(self.database_engine, PsycopgEngine)
+
+        def f(
+            the_sql: str, the_args: Iterable[Iterable[Any]]
+        ) -> Iterable[Tuple[Any, ...]]:
+            with self.txn.copy(the_sql, the_args) as copy:
+                yield from copy.rows()
+
+        return self._do_execute(f, sql, args)
 
     def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
         self._do_execute(self.txn.execute, sql, parameters)
@@ -1201,7 +1216,7 @@ class DatabasePool:
                 table,
                 ", ".join(k for k in keys),
             )
-            txn.copy_write(sql, values)
+            txn.copy_write(sql, (), values)
 
         else:
             sql = "INSERT INTO %s (%s) VALUES(%s)" % (