diff options
-rw-r--r-- | synapse/storage/database.py | 67 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 76 |
2 files changed, 60 insertions, 83 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index aed1a1742e..695229bc91 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import itertools import logging import time import types @@ -62,7 +63,7 @@ from synapse.storage.engines import ( BaseDatabaseEngine, Psycopg2Engine, PsycopgEngine, - Sqlite3Engine, + Sqlite3Engine, PostgresEngine, ) from synapse.storage.types import Connection, Cursor, SQLQueryParameters from synapse.util.async_helpers import delay_cancellation @@ -399,7 +400,7 @@ class LoggingTransaction: def execute_values( self, sql: str, - values: Iterable[Iterable[Any]], + values: Sequence[Sequence[Any]], template: Optional[str] = None, fetch: bool = True, ) -> List[Tuple]: @@ -412,19 +413,43 @@ class LoggingTransaction: The `template` is the snippet to merge to every item in argslist to compose the query. """ - assert isinstance(self.database_engine, Psycopg2Engine) + assert isinstance(self.database_engine, PostgresEngine) - from psycopg2.extras import execute_values + if isinstance(self.database_engine, Psycopg2Engine): - return self._do_execute( - # TODO: is it safe for values to be Iterable[Iterable[Any]] here? - # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence] - lambda the_sql, the_values: execute_values( - self.txn, the_sql, the_values, template=template, fetch=fetch - ), - sql, - values, - ) + from psycopg2.extras import execute_values + + return self._do_execute( + # TODO: is it safe for values to be Iterable[Iterable[Any]] here? + # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence] + lambda the_sql, the_values: execute_values( + self.txn, the_sql, the_values, template=template, fetch=fetch + ), + sql, + values, + ) + else: + # We use fetch = False to mean a writable query. You *might* be able + # to morph that into a COPY (...) FROM STDIN, but it isn't worth the + # effort for the few places we set fetch = False. + assert fetch is True + + # execute_values requires a single replacement, but we need to expand it + # for COPY. This assumes all inner sequences are the same length. + value_str = "(" + ", ".join("?" for _ in next(iter(values))) + ")" + sql = sql.replace("?", ", ".join(value_str for _ in values)) + + # Wrap the SQL in the COPY statement. + sql = f"COPY ({sql}) TO STDOUT" + + def f( + the_sql: str, the_args: Sequence[Sequence[Any]] + ) -> Iterable[Tuple[Any, ...]]: + with self.txn.copy(the_sql, the_args) as copy: + yield from copy.rows() + + # Flatten the values. + return self._do_execute(f, sql, list(itertools.chain.from_iterable(values))) def copy_write( self, sql: str, args: Iterable[Any], values: Iterable[Iterable[Any]] @@ -441,20 +466,6 @@ class LoggingTransaction: self._do_execute(f, sql, args, values) - def copy_read( - self, sql: str, args: Iterable[Iterable[Any]] - ) -> Iterable[Tuple[Any, ...]]: - """Corresponds to a PostgreSQL COPY (...) TO STDOUT call.""" - assert isinstance(self.database_engine, PsycopgEngine) - - def f( - the_sql: str, the_args: Iterable[Iterable[Any]] - ) -> Iterable[Tuple[Any, ...]]: - with self.txn.copy(the_sql, the_args) as copy: - yield from copy.rows() - - return self._do_execute(f, sql, args) - def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None: self._do_execute(self.txn.execute, sql, parameters) @@ -1187,7 +1198,7 @@ class DatabasePool: txn: LoggingTransaction, table: str, keys: Collection[str], - values: Iterable[Iterable[Any]], + values: Sequence[Sequence[Any]], ) -> None: """Executes an INSERT query on the named table. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6d0d8a5402..d4251be7e7 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -311,34 +311,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas results = set() if isinstance(self.database_engine, PostgresEngine): - if isinstance(self.database_engine, Psycopg2Engine): - # We can use `execute_values` to efficiently fetch the gaps when - # using postgres. - sql = """ - SELECT event_id - FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) - WHERE - c.chain_id = l.chain_id - AND sequence_number <= max_seq - """ - rows = txn.execute_values(sql, chains.items()) - else: - sql = """ - COPY ( - SELECT event_id - FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, max_seq) - WHERE - c.chain_id = l.chain_id - AND sequence_number <= max_seq - ) - TO STDOUT - """ % ( - ", ".join("(?, ?)" for _ in chains) - ) - # Flatten the arguments. - rows = txn.copy_read( - sql, list(itertools.chain.from_iterable(chains.items())) - ) + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) + WHERE + c.chain_id = l.chain_id + AND sequence_number <= max_seq + """ + rows = txn.execute_values(sql, chains.items()) results.update(r for r, in rows) else: # For SQLite we just fall back to doing a noddy for loop. @@ -599,38 +581,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return result if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + """ + args = [ (chain_id, min_no, max_no) for chain_id, (min_no, max_no) in chain_to_gap.items() ] - if isinstance(self.database_engine, Psycopg2Engine): - # We can use `execute_values` to efficiently fetch the gaps when - # using postgres. - sql = """ - SELECT event_id - FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) - WHERE - c.chain_id = l.chain_id - AND min_seq < sequence_number AND sequence_number <= max_seq - """ - - rows = txn.execute_values(sql, args) - else: - sql = """ - COPY ( - SELECT event_id - FROM event_auth_chains AS c, (VALUES %s) AS l(chain_id, min_seq, max_seq) - WHERE - c.chain_id = l.chain_id - AND min_seq < sequence_number AND sequence_number <= max_seq - ) - TO STDOUT - """ % ( - ", ".join("(?, ?, ?)" for _ in args) - ) - # Flatten the arguments. - rows = txn.copy_read(sql, list(itertools.chain.from_iterable(args))) + rows = txn.execute_values(sql, args) result.update(r for r, in rows) else: # For SQLite we just fall back to doing a noddy for loop. |