 synapse/storage/database.py                        | 67
 synapse/storage/databases/main/event_federation.py | 76
 2 files changed, 60 insertions(+), 83 deletions(-)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index aed1a1742e..695229bc91 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import itertools
 import logging
 import time
 import types
@@ -62,7 +63,7 @@ from synapse.storage.engines import (
     BaseDatabaseEngine,
     Psycopg2Engine,
     PsycopgEngine,
-    Sqlite3Engine,
+    Sqlite3Engine, PostgresEngine,
 )
 from synapse.storage.types import Connection, Cursor, SQLQueryParameters
 from synapse.util.async_helpers import delay_cancellation
@@ -399,7 +400,7 @@ class LoggingTransaction:
     def execute_values(
         self,
         sql: str,
-        values: Iterable[Iterable[Any]],
+        values: Sequence[Sequence[Any]],
         template: Optional[str] = None,
         fetch: bool = True,
     ) -> List[Tuple]:
@@ -412,19 +413,43 @@ class LoggingTransaction:
         The `template` is the snippet to merge to every item in argslist to
         compose the query.
         """
-        assert isinstance(self.database_engine, Psycopg2Engine)
+        assert isinstance(self.database_engine, PostgresEngine)
 
-        from psycopg2.extras import execute_values
+        if isinstance(self.database_engine, Psycopg2Engine):
 
-        return self._do_execute(
-            # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
-            # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
-            lambda the_sql, the_values: execute_values(
-                self.txn, the_sql, the_values, template=template, fetch=fetch
-            ),
-            sql,
-            values,
-        )
+            from psycopg2.extras import execute_values
+
+            return self._do_execute(
+                # `values` must be a Sequence[Sequence], as required by
+                # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values
+                lambda the_sql, the_values: execute_values(
+                    self.txn, the_sql, the_values, template=template, fetch=fetch
+                ),
+                sql,
+                values,
+            )
+        else:
+            # We use fetch = False to mean a writable query. You *might* be able
+            # to morph that into a COPY (...) FROM STDIN, but it isn't worth the
+            # effort for the few places we set fetch = False.
+            assert fetch is True
+
+            # The SQL has a single "?" standing for the whole VALUES list; expand it
+            # to one "(?, ...)" group per row (assumes non-empty, equal-length rows).
+            value_str = "(" + ", ".join("?" for _ in next(iter(values))) + ")"
+            sql = sql.replace("?", ", ".join(value_str for _ in values))
+
+            # Wrap the SQL in the COPY statement.
+            sql = f"COPY ({sql}) TO STDOUT"
+
+            def f(
+                the_sql: str, the_args: Sequence[Any]
+            ) -> List[Tuple[Any, ...]]:
+                with self.txn.copy(the_sql, the_args) as copy:
+                    return list(copy.rows())
+
+            # Flatten the values.
+            return self._do_execute(f, sql, list(itertools.chain.from_iterable(values)))
 
     def copy_write(
         self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]]
@@ -441,20 +466,6 @@ class LoggingTransaction:
 
         self._do_execute(f, sql, args, values)
 
-    def copy_read(
-        self, sql: str, args: Iterable[Iterable[Any]]
-    ) -> Iterable[Tuple[Any, ...]]:
-        """Corresponds to a PostgreSQL COPY (...) TO STDOUT call."""
-        assert isinstance(self.database_engine, PsycopgEngine)
-
-        def f(
-            the_sql: str, the_args: Iterable[Iterable[Any]]
-        ) -> Iterable[Tuple[Any, ...]]:
-            with self.txn.copy(the_sql, the_args) as copy:
-                yield from copy.rows()
-
-        return self._do_execute(f, sql, args)
-
     def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
         self._do_execute(self.txn.execute, sql, parameters)
 
@@ -1187,7 +1198,7 @@ class DatabasePool:
         txn: LoggingTransaction,
         table: str,
         keys: Collection[str],
-        values: Iterable[Iterable[Any]],
+        values: Sequence[Sequence[Any]],
     ) -> None:
         """Executes an INSERT query on the named table.
 
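To make the new psycopg (3) code path concrete, here is a minimal, self-contained sketch of the rewriting that `execute_values` now performs before handing the query to `txn.copy`. The helper name `expand_for_copy` and the sample data are illustrative only; the real logic lives inline in `LoggingTransaction.execute_values` above.

import itertools
from typing import Any, List, Sequence, Tuple

def expand_for_copy(
    sql: str, values: Sequence[Sequence[Any]]
) -> Tuple[str, List[Any]]:
    # Build one "(?, ?, ...)" group per row; assumes every row has the same
    # length and that `sql` contains exactly one "?" for the VALUES list.
    group = "(" + ", ".join("?" for _ in values[0]) + ")"
    expanded = sql.replace("?", ", ".join(group for _ in values))
    # Wrap the query so the result rows are streamed back out via COPY.
    flat_args = list(itertools.chain.from_iterable(values))
    return f"COPY ({expanded}) TO STDOUT", flat_args

sql, args = expand_for_copy(
    "SELECT event_id FROM event_auth_chains AS c,"
    " (VALUES ?) AS l(chain_id, max_seq)"
    " WHERE c.chain_id = l.chain_id AND sequence_number <= max_seq",
    [(5, 10), (6, 3)],
)
# sql  == "COPY (SELECT ... (VALUES (?, ?), (?, ?)) ...) TO STDOUT"
# args == [5, 10, 6, 3]

The COPY wrapping is presumably used because psycopg (3) ships no equivalent of psycopg2's `execute_values` helper; expanding the placeholder this way keeps the callers' SQL identical for both drivers.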
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d0d8a5402..d4251be7e7 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -311,34 +311,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             results = set()
 
         if isinstance(self.database_engine, PostgresEngine):
-            if isinstance(self.database_engine, Psycopg2Engine):
-                # We can use `execute_values` to efficiently fetch the gaps when
-                # using postgres.
-                sql = """
-                    SELECT event_id
-                    FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
-                    WHERE
-                        c.chain_id = l.chain_id
-                        AND sequence_number <= max_seq
-                """
-                rows = txn.execute_values(sql, chains.items())
-            else:
-                sql = """
-                COPY (
-                    SELECT event_id
-                    FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq)
-                    WHERE
-                        c.chain_id = l.chain_id
-                        AND sequence_number <= max_seq
-                    )
-                TO STDOUT
-                """ % (
-                    ", ".join("(?, ?)" for _ in chains)
-                )
-                # Flatten the arguments.
-                rows = txn.copy_read(
-                    sql, list(itertools.chain.from_iterable(chains.items()))
-                )
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND sequence_number <= max_seq
+            """
+            rows = txn.execute_values(sql, list(chains.items()))
             results.update(r for r, in rows)
         else:
             # For SQLite we just fall back to doing a noddy for loop.
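As a rough illustration of what each driver ends up running for the query above, assuming made-up chain data (`chains` maps a chain ID to the largest sequence number already covered):

chains = {17: 6, 42: 3}

rows = txn.execute_values(sql, list(chains.items()))

# psycopg2: psycopg2.extras.execute_values substitutes the rows into the single
# placeholder, roughly:
#     ... FROM event_auth_chains AS c,
#         (VALUES (17, 6), (42, 3)) AS l(chain_id, max_seq) ...
# psycopg (3): execute_values expands the placeholder itself and wraps the query:
#     COPY (SELECT event_id FROM event_auth_chains AS c,
#           (VALUES (?, ?), (?, ?)) AS l(chain_id, max_seq) ...) TO STDOUT
#     executed with the flattened parameters [17, 6, 42, 3].
results.update(r for r, in rows)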
@@ -599,38 +581,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             return result
 
         if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND min_seq < sequence_number AND sequence_number <= max_seq
+            """
+
             args = [
                 (chain_id, min_no, max_no)
                 for chain_id, (min_no, max_no) in chain_to_gap.items()
             ]
 
-            if isinstance(self.database_engine, Psycopg2Engine):
-                # We can use `execute_values` to efficiently fetch the gaps when
-                # using postgres.
-                sql = """
-                    SELECT event_id
-                    FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
-                    WHERE
-                        c.chain_id = l.chain_id
-                        AND min_seq < sequence_number AND sequence_number <= max_seq
-                """
-
-                rows = txn.execute_values(sql, args)
-            else:
-                sql = """
-                COPY (
-                    SELECT event_id
-                    FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq)
-                    WHERE
-                        c.chain_id = l.chain_id
-                        AND min_seq < sequence_number AND sequence_number <= max_seq
-                    )
-                TO STDOUT
-                """ % (
-                    ", ".join("(?, ?, ?)" for _ in args)
-                )
-                # Flatten the arguments.
-                rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args)))
+            rows = txn.execute_values(sql, args)
             result.update(r for r, in rows)
         else:
             # For SQLite we just fall back to doing a noddy for loop.
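Similarly for the gap query, a small worked example of how `chain_to_gap` becomes the VALUES rows (the numbers are invented):

chain_to_gap = {17: (2, 6), 42: (0, 3)}  # chain_id -> (min_seq, max_seq)

args = [
    (chain_id, min_no, max_no)
    for chain_id, (min_no, max_no) in chain_to_gap.items()
]
# args == [(17, 2, 6), (42, 0, 3)]: one VALUES row per chain; the query then
# returns the event_ids with min_seq < sequence_number <= max_seq for each chain.

rows = txn.execute_values(sql, args)
result.update(r for r, in rows)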