diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 45962d07cc..aed1a1742e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -426,19 +426,34 @@ class LoggingTransaction:
values,
)
- def copy_write(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
- # TODO use _do_execute
- with self.txn.copy(sql) as copy:
- for record in args:
- copy.write_row(record)
+ def copy_write(
+ self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
+ ) -> None:
+ """Corresponds to a PostgreSQL COPY (...) FROM STDIN call."""
+ assert isinstance(self.database_engine, PsycopgEngine)
+
+ def f(
+ the_sql: str, the_args: Iterable[Any], the_values: Iterable[Iterable[Any]]
+ ) -> None:
+ with self.txn.copy(the_sql, the_args) as copy:
+ for record in the_values:
+ copy.write_row(record)
+
+ self._do_execute(f, sql, args, values)
def copy_read(
self, sql: str, args: Iterable[Iterable[Any]]
) -> Iterable[Tuple[Any, ...]]:
- # TODO use _do_execute
- sql = self.database_engine.convert_param_style(sql)
- with self.txn.copy(sql, args) as copy:
- yield from copy.rows()
+ """Corresponds to a PostgreSQL COPY (...) TO STDOUT call."""
+ assert isinstance(self.database_engine, PsycopgEngine)
+
+ def f(
+ the_sql: str, the_args: Iterable[Iterable[Any]]
+ ) -> Iterable[Tuple[Any, ...]]:
+ with self.txn.copy(the_sql, the_args) as copy:
+ yield from copy.rows()
+
+ return self._do_execute(f, sql, args)
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
self._do_execute(self.txn.execute, sql, parameters)
@@ -1201,7 +1216,7 @@ class DatabasePool:
table,
", ".join(k for k in keys),
)
- txn.copy_write(sql, values)
+ txn.copy_write(sql, (), values)
else:
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|