diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index aed1a1742e..695229bc91 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
+import itertools
import logging
import time
import types
@@ -62,7 +63,7 @@ from synapse.storage.engines import (
BaseDatabaseEngine,
Psycopg2Engine,
PsycopgEngine,
- Sqlite3Engine,
+ Sqlite3Engine, PostgresEngine,
)
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.util.async_helpers import delay_cancellation
@@ -399,7 +400,7 @@ class LoggingTransaction:
def execute_values(
self,
sql: str,
- values: Iterable[Iterable[Any]],
+ values: Sequence[Sequence[Any]],
template: Optional[str] = None,
fetch: bool = True,
) -> List[Tuple]:
@@ -412,19 +413,43 @@ class LoggingTransaction:
The `template` is the snippet to merge to every item in argslist to
compose the query.
"""
- assert isinstance(self.database_engine, Psycopg2Engine)
+ assert isinstance(self.database_engine, PostgresEngine)
- from psycopg2.extras import execute_values
+ if isinstance(self.database_engine, Psycopg2Engine):
- return self._do_execute(
- # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
- # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
- lambda the_sql, the_values: execute_values(
- self.txn, the_sql, the_values, template=template, fetch=fetch
- ),
- sql,
- values,
- )
+ from psycopg2.extras import execute_values
+
+ return self._do_execute(
+ # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
+ # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
+ lambda the_sql, the_values: execute_values(
+ self.txn, the_sql, the_values, template=template, fetch=fetch
+ ),
+ sql,
+ values,
+ )
+ else:
+ # We use fetch = False to mean a writable query. You *might* be able
+ # to morph that into a COPY (...) FROM STDIN, but it isn't worth the
+ # effort for the few places we set fetch = False.
+ assert fetch is True
+
+ # execute_values requires a single replacement, but we need to expand it
+ # for COPY. This assumes all inner sequences are the same length.
+ value_str = "(" + ", ".join("?" for _ in next(iter(values))) + ")"
+ sql = sql.replace("?", ", ".join(value_str for _ in values))
+
+ # Wrap the SQL in the COPY statement.
+ sql = f"COPY ({sql}) TO STDOUT"
+
+ def f(
+ the_sql: str, the_args: Sequence[Sequence[Any]]
+ ) -> Iterable[Tuple[Any, ...]]:
+ with self.txn.copy(the_sql, the_args) as copy:
+ yield from copy.rows()
+
+ # Flatten the values.
+ return self._do_execute(f, sql, list(itertools.chain.from_iterable(values)))
def copy_write(
self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
@@ -441,20 +466,6 @@ class LoggingTransaction:
self._do_execute(f, sql, args, values)
- def copy_read(
- self, sql: str, args: Iterable[Iterable[Any]]
- ) -> Iterable[Tuple[Any, ...]]:
- """Corresponds to a PostgreSQL COPY (...) TO STDOUT call."""
- assert isinstance(self.database_engine, PsycopgEngine)
-
- def f(
- the_sql: str, the_args: Iterable[Iterable[Any]]
- ) -> Iterable[Tuple[Any, ...]]:
- with self.txn.copy(the_sql, the_args) as copy:
- yield from copy.rows()
-
- return self._do_execute(f, sql, args)
-
def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
self._do_execute(self.txn.execute, sql, parameters)
@@ -1187,7 +1198,7 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
keys: Collection[str],
- values: Iterable[Iterable[Any]],
+ values: Sequence[Sequence[Any]],
) -> None:
"""Executes an INSERT query on the named table.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d0d8a5402..d4251be7e7 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -311,34 +311,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
results = set()
if isinstance(self.database_engine, PostgresEngine):
- if isinstance(self.database_engine, Psycopg2Engine):
- # We can use `execute_values` to efficiently fetch the gaps when
- # using postgres.
- sql = """
- SELECT event_id
- FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
- WHERE
- c.chain_id = l.chain_id
- AND sequence_number <= max_seq
- """
- rows = txn.execute_values(sql, chains.items())
- else:
- sql = """
- COPY (
- SELECT event_id
- FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq)
- WHERE
- c.chain_id = l.chain_id
- AND sequence_number <= max_seq
- )
- TO STDOUT
- """ % (
- ", ".join("(?, ?)" for _ in chains)
- )
- # Flatten the arguments.
- rows = txn.copy_read(
- sql, list(itertools.chain.from_iterable(chains.items()))
- )
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND sequence_number <= max_seq
+ """
+ rows = txn.execute_values(sql, chains.items())
results.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
@@ -599,38 +581,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND min_seq < sequence_number AND sequence_number <= max_seq
+ """
+
args = [
(chain_id, min_no, max_no)
for chain_id, (min_no, max_no) in chain_to_gap.items()
]
- if isinstance(self.database_engine, Psycopg2Engine):
- # We can use `execute_values` to efficiently fetch the gaps when
- # using postgres.
- sql = """
- SELECT event_id
- FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
- WHERE
- c.chain_id = l.chain_id
- AND min_seq < sequence_number AND sequence_number <= max_seq
- """
-
- rows = txn.execute_values(sql, args)
- else:
- sql = """
- COPY (
- SELECT event_id
- FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq)
- WHERE
- c.chain_id = l.chain_id
- AND min_seq < sequence_number AND sequence_number <= max_seq
- )
- TO STDOUT
- """ % (
- ", ".join("(?, ?, ?)" for _ in args)
- )
- # Flatten the arguments.
- rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args)))
+ rows = txn.execute_values(sql, args)
result.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop.
|