diff options
author | Patrick Cloke <patrickc@matrix.org> | 2023-09-29 13:57:50 -0400 |
---|---|---|
committer | Patrick Cloke <patrickc@matrix.org> | 2023-09-29 14:07:34 -0400 |
commit | a072285e9dedd0342799aef446844af8df5f3685 (patch) | |
tree | a6b0a516256f20ada7fb3293d68da5cf09a5e68b | |
parent | Lint. (diff) | |
download | synapse-a072285e9dedd0342799aef446844af8df5f3685.tar.xz |
Use _do_execute for COPY TO/FROM.
-rw-r--r-- | synapse/storage/database.py | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 45962d07cc..aed1a1742e 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -426,19 +426,34 @@ class LoggingTransaction: values, ) - def copy_write(self, sql: str, args: Iterable[Iterable[Any]]) -> None: - # TODO use _do_execute - with self.txn.copy(sql) as copy: - for record in args: - copy.write_row(record) + def copy_write( + self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]] + ) -> None: + """Corresponds to a PostgreSQL COPY (...) FROM STDIN call.""" + assert isinstance(self.database_engine, PsycopgEngine) + + def f( + the_sql: str, the_args: Iterable[Any], the_values: Iterable[Iterable[Any]] + ) -> None: + with self.txn.copy(the_sql, the_args) as copy: + for record in the_values: + copy.write_row(record) + + self._do_execute(f, sql, args, values) def copy_read( self, sql: str, args: Iterable[Iterable[Any]] ) -> Iterable[Tuple[Any, ...]]: - # TODO use _do_execute - sql = self.database_engine.convert_param_style(sql) - with self.txn.copy(sql, args) as copy: - yield from copy.rows() + """Corresponds to a PostgreSQL COPY (...) TO STDOUT call.""" + assert isinstance(self.database_engine, PsycopgEngine) + + def f( + the_sql: str, the_args: Iterable[Iterable[Any]] + ) -> Iterable[Tuple[Any, ...]]: + with self.txn.copy(the_sql, the_args) as copy: + yield from copy.rows() + + return self._do_execute(f, sql, args) def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None: self._do_execute(self.txn.execute, sql, parameters) @@ -1201,7 +1216,7 @@ class DatabasePool: table, ", ".join(k for k in keys), ) - txn.copy_write(sql, values) + txn.copy_write(sql, (), values) else: sql = "INSERT INTO %s (%s) VALUES(%s)" % ( |